import java.util.HashMap; import java.util.Map; import scala.Tuple2; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.RandomForest; import org.apache.spark.mllib.tree.model.RandomForestModel; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; // $example off$ public class JavaRandomForestRegressionExample { public static void main(String[] args) { // $example on$ SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegressionExample").setMaster("local"); JavaSparkContext jsc = new JavaSparkContext(sparkConf); // Load and parse the data file. String datapath = "data/mllib/sample_libsvm_data.txt"; JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); // Split the data into training and test sets (30% held out for testing) JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); JavaRDD trainingData = splits[0]; JavaRDD testData = splits[1]; // Set parameters. // Empty categoricalFeaturesInfo indicates all features are continuous. Map categoricalFeaturesInfo = new HashMap<>(); Integer numTrees = 3; // Use more in practice. String featureSubsetStrategy = "auto"; // Let the algorithm choose. String impurity = "variance"; Integer maxDepth = 4; Integer maxBins = 32; Integer seed = 12345; // Train a RandomForest model. final RandomForestModel model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); // Evaluate model on test instances and compute test error JavaPairRDD predictionAndLabel = testData.mapToPair(new PairFunction() { @Override public Tuple2 call(LabeledPoint p) { return new Tuple2<>(model.predict(p.features()), p.label()); } }); Double testMSE = predictionAndLabel.map(new Function, Double>() { @Override public Double call(Tuple2 pl) { Double diff = pl._1() - pl._2(); return diff * diff; } }).reduce(new Function2() { @Override public Double call(Double a, Double b) { return a + b; } }) / testData.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression forest model:\n" + model.toDebugString()); // Save and load model model.save(jsc.sc(), "target/tmp/myRandomForestRegressionModel"); RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), "target/tmp/myRandomForestRegressionModel"); // $example off$ jsc.stop(); } } |