在scala中使用for表达式做monad运算

在haskell中,我们有语法糖‘do’帮助表达monad运算。scala中我们也有相应语法糖‘for’。

for表达式会被scala compiler做一些变换,简单的例子如下:

for {

  a <- foo

  b <- bar

} yield (a + b)

===>

foo.flatMap((a) => {

  bar.map((b) => {

     a + b

  })

})

所以我们需要实现两个方法 flatMap和map。

还是用前面的state monad作为例子, 我们给类型State加上flatMap和map。

case class State[S, A](runState: S => (S, A))(implicit m : Monad[({type M[a] = State[S, a]})#M]) {

  def map[B](f: A => B) : State[S, B] = m.bind(this, (a: A) => m.ret(f(a)))

  def flatMap[B](f: A => State[S, B]) : State[S, B] = m.bind(this, f)

}

这里我们使用了一个隐式参数,然后我们可以直接使用ret和bind。

同时加一个helper简化Monad[({type M[a] = State[S, a]})#M].ret

def ret[S, A](a: A) : State[S, A] = Monad[({type M[a] = State[S, a]})#M].ret(a)

好了,我们可以使用for表达式了,例子如下:

object Main {

  

  import StateMonad._

  

  def main(args: Array[String]) {

    val r = for {

        a <- ret[Int, Int](3)

        b <- ret[Int, Int](4)

    } yield (a+b)

    println(r.runState(1))

  }



}

你可能感兴趣的:(scala)