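/*
 * Scala monad demo.
 *
 * A monad is a type constructor M with two operations:
 *
 *   trait M[A] {
 *     def flatMap[B](f: A => M[B]): M[B]
 *   }
 *
 *   def unit[A](x: A): M[A]
 *
 * MonadDemo below plays this role with M = (Int, String): `unit` wraps a value
 * together with the initial log prefix, and `bind` plays the role of flatMap.
 */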

/**
 * Demonstrates currying and the monad pattern.
 *
 * Each step function returns a value paired with a log entry, and `unit`/`bind`
 * thread the log through a chain of such steps. A second demo contrasts nested
 * imperative loops with flatMap / for-comprehension traversal.
 */
object MonadDemo {
  // ListBuffer is needed by the imperative version of the document demo.
  import scala.collection.mutable.ListBuffer

  /** A computed value paired with its accumulated textual log. */
  type Logged = (Int, String)

  /** Adds 1 to `x` and describes the step as `"<x>+1"`. */
  def f1(x: Int): Logged = (x + 1, s"$x+1")

  /** Adds 2 to `x` and describes the step as `"<x>+2"`. */
  def f2(x: Int): Logged = (x + 2, s"$x+2")

  /** Adds 3 to `x` and describes the step as `"<x>+3"`. */
  def f3(x: Int): Logged = (x + 3, s"$x+3")

  // ----

  /** Lifts a plain value into a `Logged` value carrying the initial log prefix `"Ops: "`. */
  def unit(x: Int): Logged = (x, "Ops: ")

  /**
   * Applies `f` to the value inside `t` and appends f's log entry (plus `";"`)
   * to t's log.
   *
   * @param t the current value and log
   * @param f the next step
   * @return f's result value with the combined log
   */
  def bind(t: Logged, f: Int => Logged): Logged = {
    val (value, log) = t
    val (next, entry) = f(value)
    (next, s"$log$entry;")
  }

  def main(args: Array[String]): Unit = {
    /*
    Currying: a function of two arguments can be rewritten as a function that
    takes the first argument and returns a function of the second:
      (a, b) -> c
      a -> (b -> c)
      a -> b -> c      (function arrows associate to the right)
    */
    def add(a: Int, b: Int): Int = a + b

    // Must return the inner one-argument function. Returning `add` (the
    // two-argument function) would make curry_add(1)(3) fail to type-check.
    def curry_add(a: Int): Int => Int = {
      def add0(b: Int): Int = a + b
      add0
    }
    assert(add(1, 3) == curry_add(1)(3))

    //     val add3 = curry_add(_)(3)
    //     add3(5)

    // ----- code example ----
    // Plain business code: every step returns (result, logEntry), so the log
    // has to be threaded through by hand after each call.
    val input = 0
    val (res1, log1) = f1(input)
    val (res2, log2) = f2(res1)
    val (res3, log3) = f3(res2)
    val log = s"Ops: $log1;$log2;$log3;"
    println((res3, log))

    // We want f3(f2(f1(x))), but f1 returns (Int, String) while f2 expects an Int.
    // ----------- so transform to monad: unit/bind bridge that gap -----
    println(s"${bind(bind(bind(unit(input), f1), f2), f3)}  - high-level function version")

    // Pipeline: folding `bind` over a list of steps generalizes the nesting above.
    def pipeline(start: Logged, steps: List[Int => Logged]): Logged =
      steps.foldLeft(start)((acc, step) => bind(acc, step))

    println(s"${pipeline(unit(input), List(f1, f2, f3))}  - monad version")

    // ----------- map monad -----
    // Wrap the start value in a one-element Seq and map each bind step over it;
    // `.head` is safe because the Seq always has exactly one element.
    val resMonad = Seq(unit(input))
      .map(bind(_, f1))
      .map(bind(_, f2))
      .map(bind(_, f3))
      .head
    println(s"$resMonad  - map monad version")

    // ---- process Documents demo ---
    processDocumentsDemo()
  }

  /**
   * Collects every word longer than 4 characters from a set of documents.
   *
   * The same computation is written four ways: nested imperative loops,
   * flatMap/map, a for-comprehension, and a for-comprehension that delegates
   * extraction to `bindWord`. Results are computed but not printed.
   */
  def processDocumentsDemo(): Unit = {
    val documents = Seq(Document(), Document(), Document())

    // ----- imperative version: nested loops feeding a mutable buffer -----
    // The null checks are defensive, legacy-style guards. The functional
    // versions below assume sentences/words are never null.
    val wordsList = ListBuffer[Word]()
    documents.foreach { d =>
      if (d.sentences != null) {
        d.sentences.foreach { s =>
          /* anything business code 1 */
          if (s.words != null) {
            s.words.foreach { w =>
              if (w.length > 4) wordsList += w
              /* anything business code 2 */
            }
          }
        }
      }
    }
    //    println(wordsList)  //get results

    // ------- flatMap version ----------
    val wordList2: Seq[Word] = documents
      .flatMap(d => d.sentences.flatMap(s => s.words.filter(w => w.length > 4))) /* anything business code 1 */
      .map(w => w) /* anything business code 2 (identity placeholder) */

    // ----- for yield version -----------
    val wordList3: Seq[Word] = for {
      d <- documents
      s <- d.sentences
      w <- s.words
      if w.length > 4
    } yield w

    // Same traversal, with per-sentence processing extracted into bindWord.
    val wordList4: Seq[Word] = for {
      d <- documents
      s <- d.sentences
      w <- bindWord(s, _.words) // any data process extracted as a bind function
      if w.length > 4
    } yield {
      /* anything business code 2 */
      w
    }
  }

  /**
   * Extracts the words of a sentence, optionally through a custom extractor.
   *
   * @param s the sentence to read
   * @param f custom extractor; when `null`, `s.words` is returned unchanged
   * @return `f(s)`, or `s.words` when `f` is `null`
   */
  def bindWord(s: Sentence, f: Sentence => Seq[Word]): Seq[Word] = {
    /* anything business code 1 */
    // Option(f) maps a null extractor to None. This avoids a type pattern on a
    // function type, whose type arguments are erased and cannot be checked at runtime.
    Option(f).fold(s.words)(extract => extract(s))
  }
}

/**
 * A document made of sentences.
 *
 * @param sentences the document's sentences; empty by default, so `Document()`
 *                  still yields a document with no sentences
 */
final case class Document(sentences: Seq[Sentence] = Seq.empty)

/**
 * A sentence made of words.
 *
 * @param words the sentence's words; empty by default, so `Sentence()` still
 *              yields a sentence with no words
 */
final case class Sentence(words: Seq[Word] = Seq.empty)

/**
 * A word, represented only by its length.
 *
 * @param length the word's length; defaults to 6, so `Word()` keeps the
 *               previously hard-coded value
 */
final case class Word(length: Int = 6)

// You may also be interested in: (scala Monad demo)