Spark-MLlib的快速使用之七(决策树-分类)

(1)数据

1,2011-01-01,1,0,1,0,0,6,0,1,0.24,0.2879,0.81,0,3,13,16

2,2011-01-01,1,0,1,1,0,6,0,1,0.22,0.2727,0.8,0,8,32,40

3,2011-01-01,1,0,1,2,0,6,0,1,0.22,0.2727,0.8,0,5,27,32

含义

instant,dteday,season,yr,mnth,hr,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,casual,registered,cnt

 

(2)代码

 

public class HWDecisionTreeClass {

//【3--15】 为向量

//【16】为特征

private static class ParsePoint implements Function {

private static final Pattern SPACE = Pattern.compile(",");

 

@Override

public LabeledPoint call(String line) {

String[] parts = line.split(",");

double[] v = new double[parts.length - 3];

for (int i = 0; i < parts.length - 3; i++)

v[i] = Double.parseDouble(parts[i + 2]);

return new LabeledPoint(Double.parseDouble(parts[16]), Vectors.dense(v));

}

}

 

public static void main(String[] args) {

 

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample").setMaster("local");

JavaSparkContext jsc = new JavaSparkContext(sparkConf);

 

// 加载与解析数据

String datapath = "hour.txt";

 

JavaRDD lines = jsc.textFile(datapath);

JavaRDD traindata = lines.map(new ParsePoint());

List take = traindata.take(3);

for (LabeledPoint labeledPoint : take) {

System.out.println("----->" + labeledPoint.features());

System.out.println("----->" + labeledPoint.label());

}

// 70%的数据用于训练,30%的数据用于测试

JavaRDD[] splits = traindata.randomSplit(new double[] { 0.9, 0.1 });

// 训练数据

JavaRDD trainingData = splits[0];

// 测试数据

JavaRDD testData = splits[1];

// 设置参数 ,空的categoricalFeaturesInfo表示所有功能都是连续的。

Integer numClasses = 1900;

Map categoricalFeaturesInfo = new HashMap();

String impurity = "gini";

Integer maxDepth = 20;

Integer maxBins = 32;

// 训练DecisionTree模型进行分类。

final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,

impurity, maxDepth, maxBins);

// 使用模型进程预测,并和实际值比较

JavaPairRDD predictionAndLabel =

testData.mapToPair(new PairFunction() {

@Override

public Tuple2 call(LabeledPoint p) {

return new Tuple2(model.predict(p.features()), p.label());

}

});

System.out.println(predictionAndLabel.take(10));

Double testErr = 1.0 * predictionAndLabel.filter(new Function, Boolean>() {

@Override

public Boolean call(Tuple2 pl) {

return !pl._1().equals(pl._2());

}

}).count() / testData.count();

System.out.println("Test Error: -------------------------------------------------------------------" + testErr);

System.out.println("Learned classification tree model:\n-------------------------------------------"

+ model.toDebugString());

}

}

 

 

你可能感兴趣的:(机器学习-spark)