The Pipeline in Spark ML was presumably modeled on Python libraries such as SciPy.
1. The core of a Pipeline is:

val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))
// Fit the pipeline to training documents.
val model = pipeline.fit(training)
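For completeness, the fitted PipelineModel can then be used to score new documents. This is the follow-up step from the same Spark documentation example; the `test` DataFrame and the output column names below follow that example:

```scala
// Prepare test documents, which are unlabeled (id, text) pairs.
val test = sqlContext.createDataFrame(Seq(
  (4L, "spark i j k"),
  (5L, "l m n"),
  (6L, "mapreduce spark")
)).toDF("id", "text")

// The model runs every stage's transform() in order: tokenize,
// hash into term-frequency vectors, then predict with the fitted lr.
model.transform(test)
  .select("id", "text", "probability", "prediction")
  .show()
```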
2. The stages run in the order given by `stages`; the whole chain is wired together through DataFrame columns, with each stage's input and output columns set explicitly.
(1). Constructing the tokenizer stage

val training = sqlContext.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")

// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
(2). The TF stage

val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
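Note that HashingTF keeps no vocabulary: it maps each token to a column index with the hashing trick, which is why only `numFeatures` needs to be configured. A minimal sketch of the idea (the modular-hash formula here is an illustrative assumption, not the exact Spark internal):

```scala
// Hashing trick: bucket = non-negative hash of the term modulo numFeatures.
// Collisions are possible; a larger numFeatures reduces them.
val numFeatures = 1000
def termIndex(term: String): Int = {
  val raw = term.hashCode % numFeatures
  if (raw < 0) raw + numFeatures else raw
}

"a b c d e spark".split(" ").foreach { w =>
  println(s"$w -> ${termIndex(w)}")
}
```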
(3). The lr stage

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.01)
3. Now let's look at what pipeline.fit actually does, i.e. how the stages are chained together.
(1). First the stages are scanned by type to locate the last estimator; here that is LogisticRegression (which extends Estimator):

var indexOfLastEstimator = -1
theStages.view.zipWithIndex.foreach { case (stage, index) =>
  stage match {
    case _: Estimator[_] => indexOfLastEstimator = index
    case _ =>
  }
}
(2). Estimator stages are trained with fit, while Transformer stages are applied with transform:

theStages.view.zipWithIndex.foreach { case (stage, index) =>
  if (index <= indexOfLastEstimator) {
    val transformer = stage match {
      case estimator: Estimator[_] => estimator.fit(curDataset)
      case t: Transformer => t
      case _ =>
        throw new IllegalArgumentException(
          s"Do not support stage $stage of type ${stage.getClass}")
    }
    if (index < indexOfLastEstimator) {
      curDataset = transformer.transform(curDataset)
    }
    transformers += transformer
  } else {
    transformers += stage.asInstanceOf[Transformer]
  }
}
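To see the control flow without Spark, here is a toy re-implementation of the same dispatch: every stage up to and including the last Estimator is either fitted (Estimator) or taken as-is (Transformer), and the dataset is fed forward through transform; stages after the last Estimator are merely collected. All type and method signatures below are simplified stand-ins, not the real spark.ml classes:

```scala
// Toy stand-ins for the spark.ml abstractions; the "dataset" is just a String.
trait Stage
trait Transformer extends Stage { def transform(data: String): String }
trait Estimator extends Stage { def fit(data: String): Transformer }

def fitPipeline(stages: Seq[Stage], training: String): Seq[Transformer] = {
  val indexOfLastEstimator = stages.lastIndexWhere(_.isInstanceOf[Estimator])
  var cur = training
  val transformers = Seq.newBuilder[Transformer]
  stages.zipWithIndex.foreach { case (stage, index) =>
    if (index <= indexOfLastEstimator) {
      val t = stage match {
        case e: Estimator   => e.fit(cur)   // train on the current dataset
        case tr: Transformer => tr
      }
      // Only feed the data forward if a later stage still needs fitting.
      if (index < indexOfLastEstimator) cur = t.transform(cur)
      transformers += t
    } else {
      // Past the last Estimator nothing needs training at fit time.
      transformers += stage.asInstanceOf[Transformer]
    }
  }
  transformers.result()
}
```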
(3). Finally a PipelineModel is built from the accumulated transformers:
new PipelineModel(uid, transformers.toArray).setParent(this)
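At prediction time the PipelineModel simply folds the dataset through each fitted transformer in stage order; conceptually it does something like the following (a sketch of the idea, not the verbatim Spark source):

```scala
// Inside PipelineModel: chain every stage's transform over the input.
def transform(dataset: DataFrame): DataFrame = {
  stages.foldLeft(dataset) { (cur, transformer) =>
    transformer.transform(cur)
  }
}
```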