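One hallmark of functional programming is its pervasive use of recursion, and in some places recursion cannot be avoided. flatMap, for example, is a recursive algorithm that drives a computation forward step by step; without it there is no for-comprehension, and functional programming could hardly be called Monadic Programming. Recursion makes code concise and easy to follow, but it runs on the call stack. The stack is a limited resource, so running a recursive algorithm over a large data source often ends in a StackOverflowError. Unless this problem is solved, the functional style has little practical value.
To address stack overflow, the Scala compiler can optimize one particular recursion pattern by rewriting it into a while loop, but only when the recursive call is a tail call (TCE, Tail Call Elimination). Let's look at TCE through an example.
Here is a right fold: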
    def foldR[A,B](as: List[A], b: B, f: (A,B) => B): B = as match {
      case Nil => b
      case h :: t => f(h,foldR(t,b,f))
    }                                                 //> foldR: [A, B](as: List[A], b: B, f: (A, B) => B)B
    def add(a: Int, b: Int) = a + b                   //> add: (a: Int, b: Int)Int
    foldR((1 to 100).toList, 0, add)                  //> res0: Int = 5050
    foldR((1 to 10000).toList, 0, add)                //> java.lang.StackOverflowError
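The right fold overflows because f is applied only after the recursive call returns, so the recursive call is not in tail position. Now look at the left fold: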
    def foldL[A,B](as: List[A], b: B, f: (B,A) => B): B = as match {
      case Nil => b
      case h :: t => foldL(t,f(b,h),f)
    }                                                 //> foldL: [A, B](as: List[A], b: B, f: (B, A) => B)B
    // No StackOverflowError. The value is wrapped by Int overflow; the true sum is 5000050000.
    foldL((1 to 100000).toList, 0, add)               //> res1: Int = 705082704
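foldL does not overflow because its recursive call is the last thing it does, so the compiler can rewrite it into a loop, roughly equivalent to this:
    def foldl2[A,B](as: List[A], b: B, f: (B,A) => B): B = {
      var z = b
      var az = as
      while (true) {
        az match {
          case Nil => return z
          case x :: xs => {
            z = f(z, x)
            az = xs
          }
        }
      }
      z
    }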
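One way to have the compiler confirm that a method is eligible for this rewrite is the standard `scala.annotation.tailrec` annotation: compilation fails if the annotated method's recursive call is not in tail position. The sketch below is only an illustration; the object name `TailRecFold` and the helper `sumTo` are made up for this example, and a Long accumulator is used so the sum does not wrap around.

```scala
import scala.annotation.tailrec

object TailRecFold {
  // @tailrec asks the compiler to reject this method unless the
  // recursive call is a tail call that it can turn into a loop.
  @tailrec
  def foldL[A, B](as: List[A], b: B, f: (B, A) => B): B = as match {
    case Nil    => b
    case h :: t => foldL(t, f(b, h), f)
  }

  // Hypothetical helper: sums 1..n into a Long to avoid Int overflow.
  def sumTo(n: Int): Long =
    foldL((1 to n).toList, 0L, (acc: Long, x: Int) => acc + x)
}
```

Calling `TailRecFold.sumTo(100000)` runs in constant stack space and yields 5000050000.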
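In practice, though, writing every recursive algorithm as a tail recursion is unrealistic. Some more complex algorithms cannot be expressed tail-recursively, and TCE on the JVM is limited: the Scala compiler can only optimize local tail recursion, that is, a method calling itself in tail position.
Consider a slightly more complex example: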
    def even[A](as: List[A]): Boolean = as match {
      case Nil => true
      case h :: t => odd(t)
    }                                                 //> even: [A](as: List[A])Boolean
    def odd[A](as: List[A]): Boolean = as match {
      case Nil => false
      case h :: t => even(t)
    }                                                 //> odd: [A](as: List[A])Boolean
    even((1 to 100).toList)                           //> res2: Boolean = true
    even((1 to 101).toList)                           //> res3: Boolean = false
    odd((1 to 100).toList)                            //> res4: Boolean = false
    odd((1 to 101).toList)                            //> res5: Boolean = true
    even((1 to 10000).toList)                         //> java.lang.StackOverflowError
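Here even and odd call each other. Both calls are in tail position, but they are not self-calls, so the compiler does not eliminate them and the stack still overflows. We can instead design a data structure that trades stack for heap. Trampoline is a data structure designed specifically to solve the StackOverflow problem: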
    trait Trampoline[+A] {
      final def runT: A = this match {
        case Done(a) => a
        case More(k) => k().runT
      }
    }
    case class Done[+A](a: A) extends Trampoline[A]
    case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]
With Trampoline in hand, we can change the result type of even and odd to Trampoline:
    def even[A](as: List[A]): Trampoline[Boolean] = as match {
      case Nil => Done(true)
      case h :: t => More(() => odd(t))
    }                                                 //> even: [A](as: List[A])ch13.ex1.Trampoline[Boolean]
    def odd[A](as: List[A]): Trampoline[Boolean] = as match {
      case Nil => Done(false)
      case h :: t => More(() => even(t))
    }                                                 //> odd: [A](as: List[A])ch13.ex1.Trampoline[Boolean]
    even((1 to 10000).toList).runT                    //> res6: Boolean = true
    even((1 to 10001).toList).runT                    //> res7: Boolean = false
    odd((1 to 10000).toList).runT                     //> res8: Boolean = false
    odd((1 to 10001).toList).runT                     //> res9: Boolean = true
This time we get the correct results and no StackOverflowError. Is it really that simple?
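As a side note, the Scala standard library ships a comparable facility in `scala.util.control.TailCalls`: `done` wraps a finished value, `tailcall` defers the next call, and `result` runs the chain in a loop. The sketch below restates the parity example with it; the object name `TailCallsParity` is made up for this illustration, and this is the library's own type, not the Trampoline defined in this article.

```scala
import scala.util.control.TailCalls._

object TailCallsParity {
  // Each step returns a description of the next call instead of making it.
  def even[A](as: List[A]): TailRec[Boolean] = as match {
    case Nil    => done(true)
    case _ :: t => tailcall(odd(t))
  }

  def odd[A](as: List[A]): TailRec[Boolean] = as match {
    case Nil    => done(false)
    case _ :: t => tailcall(even(t))
  }
}
```

`TailCallsParity.even((1 to 100000).toList).result` evaluates to true without growing the stack, which mirrors what `runT` does for Done and More.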
Let's analyze a more realistic and somewhat more complex example, in which we traverse a List while maintaining a state. First we need a State type:
    case class State[S,+A](runS: S => (A,S)) {
      import State._
      def flatMap[B](f: A => State[S,B]): State[S,B] = State[S,B] {
        s => {
          val (a1,s1) = runS(s)
          f(a1) runS s1
        }
      }
      def map[B](f: A => B): State[S,B] = flatMap( a => unit(f(a)))
    }
    object State {
      def unit[S,A](a: A) = State[S,A] { s => (a,s) }
      def getState[S]: State[S,S] = State[S,S] { s => (s,s) }
      def setState[S](s: S): State[S,Unit] = State[S,Unit] { _ => ((),s)}
    }
    import State._   // brings unit, getState and setState into scope
    def zip[A](as: List[A]): List[(A,Int)] = {
      as.foldLeft(
        unit[Int,List[(A,Int)]](List()))(
        (acc,a) => for {
          xs <- acc
          n <- getState[Int]
          _ <- setState[Int](n + 1)
        } yield (a,n) :: xs
      ).runS(0)._1.reverse
    }                                                 //> zip: [A](as: List[A])List[(A, Int)]
    zip((1 to 10).toList)                             //> res0: List[(Int, Int)] = List((1,0), (2,1), (3,2), (4,3), (5,4), (6,5), (7,6
                                                      //| ), (8,7), (9,8), (10,9))
    zip((1 to 10000).toList)                          //> java.lang.StackOverflowError
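In principle foldLeft is tail-recursive, so why does a StackOverflowError appear? The reason is that flatMap of State is itself a recursive algorithm: each step wraps the previous one, and running the result can overflow the stack as well. How can we improve this? Could we, as above, simply change the result type of the State transition to Trampoline?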
    case class State[S,A](runS: S => Trampoline[(A,S)]) {
      import State._
      def flatMap[B](f: A => State[S,B]): State[S,B] = State[S,B] {
        s => More(() => {
          val (a1,s1) = runS(s).runT
          More(() => f(a1) runS s1)
        })
      }
      def map[B](f: A => B): State[S,B] = flatMap( a => unit(f(a)))
    }
    object State {
      def unit[S,A](a: A) = State[S,A] { s => Done((a,s)) }
      def getState[S]: State[S,S] = State[S,S] { s => Done((s,s)) }
      def setState[S](s: S): State[S,Unit] = State[S,Unit] { _ => Done(((),s))}
    }
    trait Trampoline[+A] {
      final def runT: A = this match {
        case Done(a) => a
        case More(k) => k().runT
      }
    }
    case class Done[+A](a: A) extends Trampoline[A]
    case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]

    import State._
    def zip[A](as: List[A]): List[(A,Int)] = {
      as.foldLeft(
        unit[Int,List[(A,Int)]](List()))(
        (acc,a) => for {
          xs <- acc
          n <- getState[Int]
          _ <- setState[Int](n + 1)
        } yield (a,n) :: xs
      ).runS(0).runT._1.reverse
    }                                                 //> zip: [A](as: List[A])List[(A, Int)]
    zip((1 to 10).toList)                             //> res0: List[(Int, Int)] = List((1,0), (2,1), (3,2), (4,3), (5,4), (6,5), (7,
                                                      //| 6), (8,7), (9,8), (10,9))
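    zip((1 to 10000).toList)                          //> java.lang.StackOverflowError
The overflow is still there: flatMap calls runS(s).runT inside the More thunk, and that inner runT is not a tail call, so every step of the chain nests another runT on the stack.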
A first attempt is to give Trampoline its own flatMap:
    trait Trampoline[+A] {
      final def runT: A = this match {
        case Done(a) => a
        case More(k) => k().runT
      }
      def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = {
        this match {
          case Done(a) => f(a)
          case More(k) => f(runT)
        }
      }
    }
    case class Done[+A](a: A) extends Trampoline[A]
    case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]
State.flatMap can then call it instead of calling runT directly:
    case class State[S,A](runS: S => Trampoline[(A,S)]) {
      import State._
      def flatMap[B](f: A => State[S,B]): State[S,B] = State[S,B] {
        s => More(() => {
          // val (a1,s1) = runS(s).runT
          // More(() => f(a1) runS s1)
          runS(s) flatMap {   // runS(s) >>> Trampoline
            case (a1,s1) => More(() => f(a1) runS s1)
          }
        })
      }
      def map[B](f: A => B): State[S,B] = flatMap( a => unit(f(a)))
    }
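This does not help yet, because the More case of Trampoline.flatMap still calls runT. The next step is to represent the flatMap call itself as data by adding a FlatMap case:
    case class Done[+A](a: A) extends Trampoline[A]
    case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]
    case class FlatMap[A,B](sub: Trampoline[A], k: A => Trampoline[B]) extends Trampoline[B]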
flatMap now only builds a FlatMap node, and a new resume function interprets the structure one step at a time:
    trait Trampoline[+A] {
      final def runT: A = resume match {
        case Right(a) => a
        case Left(k) => k().runT
      }
      def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = {
        this match {
          // case Done(a) => f(a)
          // case More(k) => f(runT)
          case FlatMap(a,g) => FlatMap(a, (x: Any) => g(x) flatMap f)
          case x => FlatMap(x, f)
        }
      }
      def map[B](f: A => B) = flatMap(a => Done(f(a)))
      def resume: Either[() => Trampoline[A], A] = this match {
        case Done(a) => Right(a)
        case More(k) => Left(k)
        case FlatMap(a,f) => a match {
          case Done(v) => f(v).resume
          case More(k) => Left(() => k() flatMap f)
          case FlatMap(b,g) => FlatMap(b, (x: Any) => g(x) flatMap f).resume
        }
      }
    }
    case class Done[+A](a: A) extends Trampoline[A]
    case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]
    case class FlatMap[A,B](sub: Trampoline[A], k: A => Trampoline[B]) extends Trampoline[B]
The adjustment to Trampoline above relies on the associativity law of Monad:
    FlatMap(FlatMap(b,g),f) == FlatMap(b, x => FlatMap(g(x),f))
Once re-associated to the right, FlatMap can correctly express a multi-step computation.
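The effect of this re-association can be checked with a small, self-contained sketch. It is not the article's code: it uses an invariant type parameter and an external tail-recursive `run` function (a shape often used in Scala FP teaching material) instead of `resume` and `runT`, and the names `TrampolineSketch`, `run`, `addOne`, `twice`, `leftAssoc`, `rightAssoc` and `deepLeft` are all made up for this illustration.

```scala
import scala.annotation.tailrec

object TrampolineSketch {
  sealed trait Trampoline[A] {
    // flatMap only records the step as data; nothing runs here.
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = FlatMap(this, f)
    def map[B](f: A => B): Trampoline[B] = flatMap(a => Done(f(a)))
  }
  final case class Done[A](a: A) extends Trampoline[A]
  final case class More[A](k: () => Trampoline[A]) extends Trampoline[A]
  final case class FlatMap[A, B](sub: Trampoline[A], k: A => Trampoline[B])
      extends Trampoline[B]

  // Interpreter loop: every recursive call is a tail call.
  @tailrec
  def run[A](t: Trampoline[A]): A = t match {
    case Done(a) => a
    case More(k) => run(k())
    case FlatMap(sub, f) => sub match {
      case Done(a)       => run(f(a))
      case More(k)       => run(k().flatMap(f))
      // Associativity: (y flatMap g) flatMap f == y flatMap (a => g(a) flatMap f)
      case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
    }
  }

  val addOne: Int => Trampoline[Int] = n => Done(n + 1)
  val twice: Int => Trampoline[Int]  = n => More(() => Done(n * 2))

  // The same computation, associated to the left and to the right.
  val leftAssoc: Trampoline[Int]  = Done(1).flatMap(addOne).flatMap(twice)
  val rightAssoc: Trampoline[Int] = Done(1).flatMap(x => addOne(x).flatMap(twice))

  // A deeply left-nested chain, like the one foldLeft builds in zip.
  val deepLeft: Trampoline[Int] =
    (1 to 100000).foldLeft(Done(0): Trampoline[Int])((acc, _) => acc.flatMap(addOne))
}
```

Both `run(leftAssoc)` and `run(rightAssoc)` give 4, and `run(deepLeft)` gives 100000 without overflowing, because each left-nested FlatMap is rotated to the right before it is evaluated.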
Now let's try running zip again:
    def zip[A](as: List[A]): List[(A,Int)] = {
      as.foldLeft(
        unit[Int,List[(A,Int)]](List()))(
        (acc,a) => for {
          xs <- acc
          n <- getState[Int]
          _ <- setState[Int](n + 1)
        } yield (a,n) :: xs
      ).runS(0).runT._1.reverse
    }                                                 //> zip: [A](as: List[A])List[(A, Int)]
    zip((1 to 10000).toList)                          //> res0: List[(Int, Int)] = List((1,0), (2,1), (3,2), (4,3), (5,4), (6,5), (7,
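In fact, we can treat Trampoline as a general-purpose solution to stack overflow.
First, we can use the Monad properties of Trampoline to control how function calls are sequenced, as follows: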
    val x = f()
    val y = g(x)
    h(y)
    // the three function calls above can be written as:
    for {
      x <- f()
      y <- g(x)
      z <- h(y)
    } yield z
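To make the sequencing explicit, here is a sketch of how such a for-comprehension expands into flatMap and map calls. To stay self-contained it uses the standard library's `TailRec` from `scala.util.control.TailCalls` as a stand-in for this article's Trampoline; the object name `Desugar` and the bodies of `f`, `g` and `h` are placeholders invented for the example.

```scala
import scala.util.control.TailCalls._

object Desugar {
  // Placeholder steps; each one returns a deferred computation.
  def f(): TailRec[Int]       = done(3)
  def g(x: Int): TailRec[Int] = tailcall(done(x + 1))
  def h(y: Int): TailRec[Int] = tailcall(done(y * 2))

  // The for-comprehension form.
  val sugared: TailRec[Int] =
    for {
      x <- f()
      y <- g(x)
      z <- h(y)
    } yield z

  // What the compiler expands it to: every generator but the last
  // becomes flatMap, and the last one becomes map.
  val desugared: TailRec[Int] =
    f().flatMap(x => g(x).flatMap(y => h(y).map(z => z)))
}
```

Both `Desugar.sugared.result` and `Desugar.desugared.result` evaluate to 8, that is (3 + 1) * 2.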
A concrete example:
    implicit def step[A](a: => A): Trampoline[A] = {
      More(() => Done(a))
    }                                                 //> step: [A](a: => A)ch13.ex1.Trampoline[A]
    def getNum: Double = 3                            //> getNum: => Double
    def addOne(x: Double) = x + 1                     //> addOne: (x: Double)Double
    def timesTwo(x: Double) = x * 2                   //> timesTwo: (x: Double)Double
    (for {
      x <- getNum
      y <- addOne(x)
      z <- timesTwo(y)
    } yield z).runT                                   //> res6: Double = 8.0

    def fib(n: Int): Trampoline[Int] = {
      if (n <= 1) Done(n) else for {
        x <- More(() => fib(n-1))
        y <- More(() => fib(n-2))
      } yield x + y
    }                                                 //> fib: (n: Int)ch13.ex1.Trampoline[Int]
    (fib(10)).runT                                    //> res7: Int = 55
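From the above we can see that flatMap gives us flow control over Trampoline computations. We can also interleave several Trampoline computations to simulate parallel execution (the steps take turns on a single thread rather than running on several threads):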
    trait Trampoline[+A] {
      final def runT: A = resume match {
        case Right(a) => a
        case Left(k) => k().runT
      }
      def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = {
        this match {
          // case Done(a) => f(a)
          // case More(k) => f(runT)
          case FlatMap(a,g) => FlatMap(a, (x: Any) => g(x) flatMap f)
          case x => FlatMap(x, f)
        }
      }
      def map[B](f: A => B) = flatMap(a => Done(f(a)))
      def resume: Either[() => Trampoline[A], A] = this match {
        case Done(a) => Right(a)
        case More(k) => Left(k)
        case FlatMap(a,f) => a match {
          case Done(v) => f(v).resume
          case More(k) => Left(() => k() flatMap f)
          case FlatMap(b,g) => FlatMap(b, (x: Any) => g(x) flatMap f).resume
        }
      }
      def zip[B](tb: Trampoline[B]): Trampoline[(A,B)] = {
        (this.resume, tb.resume) match {
          case (Right(a),Right(b)) => Done((a,b))
          case (Left(f),Left(g)) => More(() => f() zip g())
          case (Right(a),Left(k)) => More(() => Done(a) zip k())
          case (Left(k),Right(a)) => More(() => k() zip Done(a))
        }
      }
    }
    case class Done[+A](a: A) extends Trampoline[A]
    case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]
    case class FlatMap[A,B](sub: Trampoline[A], k: A => Trampoline[B]) extends Trampoline[B]
    def hello: Trampoline[Unit] = for {
      _ <- print("Hello ")
      _ <- println("World!")
    } yield ()                                        //> hello: => ch13.ex1.Trampoline[Unit]
    (hello zip hello zip hello).runT                  //> Hello Hello Hello World!
                                                      //| World!
                                                      //| World!
                                                      //| res8: ((Unit, Unit), Unit) = (((),()),())