vlambda博客
学习文章列表

高阶函数式编程(四):在 Kotlin 中“实现”单子(Monad)

距离应用函子(Applicative)的完结已经过去了三个月,这三个月里准备了两场线上演讲(都是 KMM 的),然后工作也到了年中考核,再加上自己拖延症犯一犯也就到了现在。现在是时候把高阶函数式编程的进度再推一推了,相比函子(Functor)与应用函子(Applicative),单子(Monad)的名气更大,很多不懂函数式编程的人也都听过它的顶顶大名。单子也也是 Functor -> Applicative -> Monad 三层级概念结构的最后一层,所以等单子讨论完成之后,《高阶函数式编程》系列就到了一个阶段性的收尾了。


本文基于前置知识,前文回顾:








什么是单子(Monad)?


按照定义,单子首先应该是应用函子,而应用函子而又首先应该是函子。


在 Haskell 中 Applicative 和 Monad 的定义如下:


class (Functor f) => Applicative f where    pure :: a -> f a    (<*>) :: f (a -> b) -> f a -> f b    class Monad m where    return :: a -> m a    (>>=) :: m a -> (a -> m b) -> m b        (>>:: m a -> m b -> m b    x >> y = x >>= \_ -> y        fail :: String -> m a    fail :: msg = error msg


声明一个 Monad 并不需要对其施加 Applicative 约束。即在 Kotlin 中 Monad 类型无需继承自 Applicative 类型,《Haskell 趣学指南》第 228 页提到了这个问题,书中说 Haskell 的设计者并没有考虑到 Applicative 会这么有用,但并不影响 Monad 是一个 Applicative 的事实,所以这虽然可能是一个问题,但是对于我们来说刚好可以让我们 Monad 的实现跨过 Applicative。


前几篇文章我们已经阅读过不少 Haskell 中概念的定义,我们可以看到 Monad 的 return 函数与 Applicative 中的 pure 函数是相同的东西,那么除此之外 Monad 中最核心的函数就是 >>= 了。它将一个包裹在 Monad 中的 a 类型值从 Monad 中取出,然后将一个类型为 a -> m b 的函数应用其上,最终得到一个包裹在 Monad 中的 b 类型值(类型为 m b)。   


实现


根据 Haskell 中的声明我们可以立刻写出 Monad 的定义:


interface Monad<T, R : Monad<T, R, S>, S : Monad<*, *, S>> {    infix fun `return`(t: T): R infix fun <A> `>>=`(function: (T) -> S): S}


由于柯里化,和 Functor 的 fmap 函数与 Applicative 的<*> 函数一样,Monad 的 >>= 函数也可以表达两种含义:


含义 1:>>= 函数表示一种接受一个 m a 单子类型与一个 a -> m b 函数类型的参数的函数,该函数可以将单子 m a 拆解取出 a 类型的元素,并将 a -> m b 类型的函数应用其上,得到 m b 类型的单子作为 >>= 函数的结果返回。


含义 2:……


等等,含义 2 好像意义不大。如果我们提供一个无参版本的 `>>=` 函数,那么它的返回值就是一个接受函数类型 (T) -> S 类型参数并返回 S 类型的值的高阶函数,这个高阶函数仅仅是执行了它的参数并将其返回结果作为自己的结果返回,并没有什么用处。我们在函数组合中的 $ 函数可能与这个类似,但 $ 函数的主要作用是可以将一个嵌套的函数调用链展平并优雅的传参,但拥有含义 2 的 `>>=` 函数则没这个功能。


return 在 Kotlin 中是个关键字,不能直接作为函数名,所以我们用 `` 将其包裹后使用,对于 >>= 这种非字母字符串也是同理。


我们看到 Haskell 的声明中还有 >> 和 fail 函数,在本文我们还无需讨论他们,所以先跳过。


List 单子


List 又出现了,这次它将作为一个单子与我们见面。虽然 Monad 的定义中没有和 Applicative 有关的类型约束,但根据概念,List 单子首先应该是一个应用函子,还记得我们的 ApplicativeList 吗,回顾一下:



@Suppress("UNCHECKED_CAST")class ApplicativeList<T>(private val coreList: List<T>) : List<T> by coreList, Functor<T, ApplicativeList<T>, ApplicativeList<*>>, Applicative<T, ApplicativeList<T>, ApplicativeList<*>> {
override infix fun <A> fmap1(function: (T) -> A): (ApplicativeList<T>) -> ApplicativeList<A> = { ApplicativeList(it.map(function)) } override infix fun <A> fmap2(function: (T) -> A): ApplicativeList<*> = fmap2Default(function)
override infix fun pure(t: T): ApplicativeList<T> = ApplicativeList(listOf(t)) override infix fun <A, F> `<*>`(function: F): ApplicativeList<A> where F : Applicative<(T) -> A, F, *>, F : Functor<(T) -> A, F, *> = ApplicativeList((function as? ApplicativeList<(T) -> A>)?.flatMap { func -> map { func(it) } } ?: throw ClassCastException("Param function must be ApplicativeList Type"))
override fun <A, F> `<**>`(function: F): (ApplicativeList<T>) -> ApplicativeList<A> where F : Applicative<(T) -> A, F, *>, F : Functor<(T) -> A, F, *> = `<*Default*>`(function) as (ApplicativeList<T>) -> ApplicativeList<A>
override fun hashCode(): Int = coreList.hashCode() override fun equals(other: Any?): Boolean = coreList == other override fun toString(): String = coreList.toString()}


我们只要给 ApplicativeList 改个名然后实现 Monad 接口即可得到 MonadList 的完整实现:


class MonadList<T>(private val coreList: List<T>) : List<T> by coreList, Functor<T, MonadList<T>, MonadList<*>>, Applicative<T, MonadList<T>, MonadList<*>>, Monad<T, MonadList<T>, MonadList<*>> {
override infix fun <A> fmap1(function: (T) -> A): (MonadList<T>) -> MonadList<A> = { MonadList(it.map(function)) } override infix fun <A> fmap2(function: (T) -> A): MonadList<A> = fmap2Default(function)
override infix fun pure(t: T): MonadList<T> = MonadList(listOf(t)) override infix fun <A, F> `<*>`(function: F): MonadList<A> where F : Applicative<(T) -> A, F, *>, F : Functor<(T) -> A, F, *> = MonadList((function as? MonadList<(T) -> A>)?.flatMap { func -> map { func(it) } } ?: throw ClassCastException("Param function must be FList Type"))
override fun <A, F> `<**>`(function: F): (MonadList<T>) -> MonadList<A> where F : Applicative<(T) -> A, F, *>, F : Functor<(T) -> A, F, *> = `<*Default*>`(function) as (MonadList<T>) -> MonadList<A>
/** * 单子的函数 */    override infix fun `return`(t: T): MonadList<T> = pure(t) override infix fun <A> `>>=`(function: (T) -> MonadList<*>): MonadList<A> = (asSequence().map(function).flatten() as Sequence<A>).toMonadList()
override fun hashCode(): Int = coreList.hashCode() override fun equals(other: Any?): Boolean = other is MonadList<*> && coreList == other.coreList override fun toString(): String = coreList.toString()
}


仔细观察 `>>=` 我们会发现,我们再次触摸到了 Kotlin 编译器类型系统的天花板。`>>=` 函数的参数 function 的返回值类型与 `>>=` 函数的返回值类型不一致;它们俩在 Monad 中都使用泛型 S 来定义,但是到了具体的 Monad 实现者这里,作为 `>>=` 函数返回值的 S 可以使用比 MonadList<*> 范围更小的 MonadList<A> 来代替;为什么?因为 MonadList<A> 类型必然是 MonadList<*> 类型,类型安全,使用者在调用 `>>=` 函数的时候期望得到 MonadList<*> 类型的结果,那我们给他一个 MonadList<A> 类型的结果自然是类型安全且符合要求的,即我们可以认为类型 MonadList<A> 是 MonadList<*> 的子类型。


但到了参数 function 这里情况又有不同,原本的 function 类型是 (T) -> MonadList<*>,如果我们擅自把 MonadList<*> 改为 MonadList<A> 就会得到类型 (T) -> MonadList<A>,由于类型  MonadList<A> 是类型 MonadList<*> 的子类型,且这两者都在函数类型的返回值位置上(即 out 位置),所以函数类型 (T) -> MonadList<A> 是函数类型 (T) -> MonadList<*> 的子类型。


上面这段黑体字有点绕,如果你确定看懂了这一段黑体字表达的子类型化关系的意思再接着看下一段:


但由于 (T) -> MonadList<*> 在函数 `>>=` 的参数位置上(即 in 位置),我们如果要确保用户在调用 `>>=` 函数时类型安全,我们只能扩大参数的类型范围,即如果要改写 (T) -> MonadList<*> 类型,那也只能改写为它的父类型;因此,将参数 function 的返回值类型改写为 MonadList<A> 是违背类型安全的非法操作。


上面两段黑色文字均涉及到类型、复合类型(泛型、函数类型)的子类型化关系的推导,会有点复杂,但是如果你能理解透彻就能理解我们当前实现的 MonadList 在使用时某些位置会逃逸编译器静态类型检查的缺陷。


我们稍微扯远一点,这个问题的根源是从哪来的?根源还是在于没有高阶类型给我们带来的困扰,我们无法分别给一个泛型的本身和其内部的泛型参数加类型上界的限制,也无法用两个泛型参数分别表示这二者,导致我们不得已要在 Monad 声明的时候就声明泛型 S,用其表示内部持有 A 类型的 Monad,但由于 A 类型是函数的泛型参数,S 只得使用 * 投影来确保其可以装入任意类型的参数。


回到正题,怎么解决这个问题?把 `>>=` 函数隐藏,在 Monad 或 Monad List 内部实现强转后向用户暴露一个类型校验自洽的 API。


既然函数 `>>=` 不直接给使用者用了,干脆把这个名字给 public 的 API 吧,重新给它起一个山寨一点的名字:`>>==`,MonadList 的完整实现变成了下型这样:


class MonadList<T>(private val coreList: List<T>) : List<T> by coreList, Functor<T, MonadList<T>, MonadList<*>>, Applicative<T, MonadList<T>, MonadList<*>>, Monad<T, MonadList<T>, MonadList<*>> {
override infix fun <A> fmap1(function: (T) -> A): (MonadList<T>) -> MonadList<A> = { MonadList(it.map(function)) } override infix fun <A> fmap2(function: (T) -> A): MonadList<A> = fmap2Default(function)
override infix fun pure(t: T): MonadList<T> = MonadList(listOf(t)) override infix fun <A, F> `<*>`(function: F): MonadList<A> where F : Applicative<(T) -> A, F, *>, F : Functor<(T) -> A, F, *> = MonadList((function as? MonadList<(T) -> A>)?.flatMap { func -> map { func(it) } } ?: throw ClassCastException("Param function must be FList Type"))
override fun <A, F> `<**>`(function: F): (MonadList<T>) -> MonadList<A> where F : Applicative<(T) -> A, F, *>, F : Functor<(T) -> A, F, *> = `<*Default*>`(function) as (MonadList<T>) -> MonadList<A>
/** * 单子的函数 */ override infix fun `return`(t: T): MonadList<T> = pure(t) override infix fun <A> `>>==`(function: (T) -> MonadList<*>): MonadList<A> = (asSequence().map(function).flatten() as Sequence<A>).toMonadList()
override fun hashCode(): Int = coreList.hashCode() override fun equals(other: Any?): Boolean = other is MonadList<*> && coreList == other.coreList override fun toString(): String = coreList.toString()
}


我们再添加刚才提到的 public 扩展函数:


infix fun <A, S : Monad<A, S, *>, T, R : Monad<T, R, S>> R.`>>=`(function: (T) -> S): S = `>>==`<A>(function)


该扩展函数只是缩小了参数的范围,本质还是调用 Monad 的 `>>==` 函数。


由于 Monad 是个 interface,它的成员无法声明为 internal,所以 `>>==` 函数还是会暴露给使用者,不过我们可以在注释中注明,请用户尽量使用扩展函数 `>>=` 而获得更完善的类型安全校验。


来玩一下 MonadList 吧:


fun main() { val intList = MonadList(listOf(3, 4, 5)) val floatList = intList.`>>=`<Float, MonadList<Float>, Int, MonadList<Int>> { MonadList(listOf(it.toFloat(), -it.toFloat())) } println(floatList)}


运行结果如下:


[3.0, -3.0, 4.0, -4.0, 5.0, -5.0]


你有没有想到些什么?


你现在也许还没想到,不过我写一个下面的例子你就能想到了(注意,下面的是伪代码,假设我们调用 `>>=` 可以省略泛型参数):


floatList `>>=` ::floatToStringMonadList `>>=` ::stringToIntMonadList `>>=` intToBooleanMonadList

  

每个 Monad 调用一次 `>>=` 之后都会得到一个持有的元素类型不同的 Monad,从而形成一个 `>>=` 函数连续调用的调用链。


哈哈哈哈,如果你跟我一样是个 Android 程序员,那马上就会想到 RxJava。在 RxJava 中每次 Observable 调用一个操作符就会产生一个新的 Observable 对象(但这个 Observable 受观察的元素类型可能不同)。RxJava 只是个例子,所有的响应式流实现都应用了 Monad 的思想,包括 Kotlin 中的 Sequence、Flow 等等。


再引申一下,在 Haskell 中 IO 也是单子的实现,Haskell 是纯函数式编程语言,不允许可变状态的出现,因此 IO 需要将可变状态的世界和整个程序隔离开,IO 只能在 IO 中调用,我们是不是能联想到 Kotlin 的 suspend 函数,它需要隔离可挂起与不可挂起的世界,因此 suspend 函数只能在 suspend 函数内部调用,这么一联系,suspend 函数与 Kotlin 协程也是 Monad 思想的应用,这里我懒得写大量的 Haskell 代码,点到为止吧。


所以,提高理论上限,可以让我们造出更牛逼的东西。


验证单子定律


老规矩,要验证一个单子是单子,它必须遵守单子定律,单子定律比应用函子少一些,一共三条:


左单位元:return x >>=  f = f x


fun main() { val emptyIntList = MonadList(listOf<Int>()) val function: (Int) -> MonadList<Float> = { MonadList(listOf(it.toFloat(), -it.toFloat())) } val result = (emptyIntList `return` 8).`>>=`<Float, MonadList<Float>, Int, MonadList<Int>>(function) == function(8) println(result)}


输出:true。


右单位元:m >>=  return = m


fun main() {    val intList = MonadList(listOf(345)) val result = intList.`>>=`<Int, MonadList<Int>, Int, MonadList<Int>>(intList::`return`) == intList println(result)}


输出:true。


结合律:(m >>=  f) >>= g = m >>= (\x -> f x >>= g)


直白的说就是,如果有一条 Monad 调用链,嵌套顺序无关紧要:


fun main() { val intList = MonadList(listOf(3, 4, 5)) val intToFloatFunction: (Int) -> MonadList<Float> = { MonadList(listOf(it.toFloat(), -it.toFloat())) } val floatToStringFunction: (Float) -> MonadList<String> = { MonadList(listOf((it * 2).toString())) } val expression1 = (intList.`>>=`<Float, MonadList<Float>, Int, MonadList<Int>>(intToFloatFunction)).`>>=`<String, MonadList<String>, Float, MonadList<Float>>(floatToStringFunction) val expression2 = intList.`>>=`<String, MonadList<String>, Int, MonadList<Int>> { intToFloatFunction(it).`>>=`<String, MonadList<String>, Float, MonadList<Float>>(floatToStringFunction) } println(expression1 == expression2)}


输出:true,比前两个复杂点,不过依然轻松验证。


很快就会再见


单子能讨论的东西很多,我们这一期也是一样,先定义与实现,后面继续讨论更多的有关 Monad 的有趣话题。我保证下次更新间隔不会太长。