MLlib is currently split into two packages:
spark.mllib contains the original RDD-based API.
spark.ml provides a higher-level API built on DataFrames, which can be used to construct machine learning pipelines.
This article uses the DataFrame-based API; a DataFrame is structured much like a MySQL table, which makes data handling convenient.
The DataFrame-based API lives in the package org.apache.spark.ml.*, and the data objects it works with come from org.apache.spark.sql.*. The JavaRDD-based API lives in org.apache.spark.mllib.*.
For more details, see the MLlib guide.
The loss function of LogisticRegression is:
$L(w; x, y) := \ln\left(1 + \exp(-y \, w^T x)\right)$, where the label $y$ is encoded as $\pm 1$.
The prediction function is: $f(z) = \frac{1}{1 + e^{-z}}$, with $z = w^T x$.
When $f(z)$ is greater than 0.5, the decision function outputs 1; otherwise it outputs 0.
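To make the rule concrete, here is a minimal Java sketch of the prediction step; the weight and feature values are made up purely for illustration:

// Sigmoid prediction rule: compute z = w^T x, squash through f(z), threshold at 0.5.
double[] w = {0.5, -1.2, 0.3};   // hypothetical weight vector
double[] x = {1.0, 0.2, -0.4};   // hypothetical feature vector
double z = 0.0;
for (int i = 0; i < w.length; i++) {
    z += w[i] * x[i];            // z = w^T x
}
double fz = 1.0 / (1.0 + Math.exp(-z));  // f(z) = 1 / (1 + e^{-z})
double label = fz > 0.5 ? 1.0 : 0.0;     // 1 if f(z) > 0.5, else 0
System.out.println("f(z) = " + fz + ", predicted label = " + label);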
There are two methods for fitting the model parameters: gradient descent and L-BFGS. For the details of gradient descent, see the linked blog post.
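For intuition about the gradient-descent option, a single update step on the loss above looks like the sketch below. The labels are in {-1, +1} as the loss assumes; the step size and data are illustrative, and Spark's actual implementation adds regularization and other details:

// One gradient-descent step on L(w; x, y) = ln(1 + exp(-y * w^T x)).
double[] w = {0.0, 0.0, 0.0};    // current weights
double[] x = {2.0, 1.0, -1.0};   // features of one training example
double y = -1.0;                 // its label, in {-1, +1}
double stepSize = 0.1;           // learning rate (illustrative)
double z = 0.0;
for (int i = 0; i < w.length; i++) {
    z += w[i] * x[i];            // z = w^T x
}
// Gradient of the loss with respect to w is -y * x / (1 + exp(y * w^T x)).
double scale = -y / (1.0 + Math.exp(y * z));
for (int i = 0; i < w.length; i++) {
    w[i] -= stepSize * scale * x[i];  // w := w - stepSize * gradient
}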
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
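The code below assumes a SparkSession named spark is already available; a minimal local setup (the application name is arbitrary) would be:

SparkSession spark = SparkSession
    .builder()
    .appName("LogisticRegressionExample")  // arbitrary application name
    .master("local[*]")                    // run locally on all cores
    .getOrCreate();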
// Construct the training data.
List<Row> dataTraining = Arrays.asList(
RowFactory.create(1.0, Vectors.dense(0.0, 1.1, 0.1)),
RowFactory.create(0.0, Vectors.dense(2.0, 1.0, -1.0)),
RowFactory.create(0.0, Vectors.dense(2.0, 1.3, 1.0)),
RowFactory.create(1.0, Vectors.dense(0.0, 1.2, -0.5))
);
StructType schema = new StructType(
new StructField[] {
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty())
}
);
Dataset<Row> training = spark.createDataFrame(dataTraining, schema);
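To sanity-check the DataFrame before training, the standard Dataset methods can print its schema and rows:

training.printSchema();  // label: double, features: vector
training.show(false);    // print all four rows without truncating the vectors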
// Test data
List<Row> dataTest = Arrays.asList(
RowFactory.create(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
RowFactory.create(0.0, Vectors.dense(3.0, 2.0, -0.1)),
RowFactory.create(1.0, Vectors.dense(0.0, 2.2, -1.5))
);
Dataset<Row> test = spark.createDataFrame(dataTest, schema);
// Create the model
LogisticRegression lr = new LogisticRegression();
// Set parameters: at most 10 iterations, regularization coefficient 0.01
lr.setMaxIter(10).setRegParam(0.01);
// Train the model
LogisticRegressionModel model1 = lr.fit(training);
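As an aside, the ParamMap import above points at a second way to specify parameters: collect them in a ParamMap and pass it to fit(), where they override anything set through the setters. This mirrors the official spark.ml example:

// Alternative: specify parameters via a ParamMap at fit time.
ParamMap paramMap = new ParamMap()
    .put(lr.maxIter().w(20))                             // specify one parameter
    .put(lr.maxIter(), 30)                               // overrides the value above
    .put(lr.regParam().w(0.1), lr.threshold().w(0.55));  // specify several at once
LogisticRegressionModel model2 = lr.fit(training, paramMap);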
// Make predictions on the test data
Dataset<Row> results = model1.transform(test);
// Inspect the model parameters:
System.out.println(
"Model was fit using parameters: " + model1.parent().extractParamMap()
);
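Besides the parameter map, the learned weights and intercept can be read directly off the fitted model:

System.out.println("Coefficients: " + model1.coefficients()
    + " Intercept: " + model1.intercept());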
// Inspect the results. The output columns of transform are:
// label (0), features (1), rawPrediction (2), probability (3), prediction (4).
results.collectAsList().forEach(row ->
    System.out.println("(" + row.get(0) + ", " + row.get(1)
        + ") -> prediction=" + row.get(3))
);
Model was fit using parameters: {
logreg_10073fbf67d3-aggregationDepth: 2,
logreg_10073fbf67d3-elasticNetParam: 0.0,
logreg_10073fbf67d3-family: auto,
logreg_10073fbf67d3-featuresCol: features,
logreg_10073fbf67d3-fitIntercept: true,
logreg_10073fbf67d3-labelCol: label,
logreg_10073fbf67d3-maxIter: 10,
logreg_10073fbf67d3-predictionCol: prediction,
logreg_10073fbf67d3-probabilityCol: probability,
logreg_10073fbf67d3-rawPredictionCol: rawPrediction,
logreg_10073fbf67d3-regParam: 0.01,
logreg_10073fbf67d3-standardization: true,
logreg_10073fbf67d3-threshold: 0.5,
logreg_10073fbf67d3-tol: 1.0E-6
}
// Note: the array printed as "prediction" here is in fact the probability column (index 3); its two entries are the probabilities that the example belongs to class 0 and class 1, respectively.
(1.0, [-1.0,1.5,1.3]) -> prediction=[0.0013759947069214296,0.9986240052930786]
(0.0, [3.0,2.0,-0.1]) -> prediction=[0.9816604009374171,0.018339599062582968]
(1.0, [0.0,2.2,-1.5]) -> prediction=[0.0016981475578358401,0.9983018524421641]
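To put a number on how well the model separates the two classes, a BinaryClassificationEvaluator can compute the area under the ROC curve on the transformed test set; this step is not part of the run above, just a natural follow-up:

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;

// Area under the ROC curve, computed from the rawPrediction column.
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
    .setLabelCol("label")
    .setRawPredictionCol("rawPrediction")
    .setMetricName("areaUnderROC");
double auc = evaluator.evaluate(results);
System.out.println("Test AUC = " + auc);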