Spark机器学习(java):ALS交替最小二乘算法

楔子

使用 Spark 机器学习库(spark.ml)中的 ALS(交替最小二乘)算法实现电影推荐。

Spark中ml和mllib的区别
Spark机器学习(10):ALS交替最小二乘算法

demo

import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.zhuzi.utils.SparkUtils;

// $example off$

/**
 * Movie-recommendation demo built on Spark ML's ALS (alternating least squares) estimator.
 *
 * <p>The program parses "userId::movieId::rating::timestamp" lines into {@link Rating} beans,
 * fits an ALS model on a random 80% split, prints the RMSE measured on the remaining 20%, and
 * then produces recommendations for all users, all movies, and small subsets of each.
 *
 * <p>Original file: JavaALS.java, package com.zhuzi.mlZ.
 *
 * @see <a href="https://www.cnblogs.com/mstk/p/7208674.html">https://www.cnblogs.com/mstk/p/7208674.html</a>
 */
public class JavaALS {
	private static final Logger log = LoggerFactory.getLogger(JavaALS.class);

	/** Number of recommendations requested per user / per movie for the "all" queries. */
	private static final int TOP_N = 5;

	/** Number of recommendations requested per user / per movie for the subset queries. */
	private static final int SUBSET_TOP_N = 10;

	/**
	 * One rating record. Exposed as a JavaBean (getters plus a no-arg constructor) so that it
	 * can be handed to {@code SparkSession.createDataFrame(rdd, Rating.class)}.
	 */
	public static class Rating implements Serializable {

		private static final long serialVersionUID = 1L;
		private int userId;
		private int movieId;
		private float rating;
		private long timestamp;

		/** No-arg constructor kept for bean-style instantiation. */
		public Rating() {
		}

		/**
		 * @param userId    id of the rating user
		 * @param movieId   id of the rated movie
		 * @param rating    rating value
		 * @param timestamp timestamp field as read from the input line
		 */
		public Rating(int userId, int movieId, float rating, long timestamp) {
			this.userId = userId;
			this.movieId = movieId;
			this.rating = rating;
			this.timestamp = timestamp;
		}

		public int getUserId() {
			return userId;
		}

		public int getMovieId() {
			return movieId;
		}

		public float getRating() {
			return rating;
		}

		public long getTimestamp() {
			return timestamp;
		}

		/**
		 * Parses one input line of the form {@code userId::movieId::rating::timestamp}.
		 *
		 * @param str the raw line
		 * @return the parsed rating
		 * @throws IllegalArgumentException if the line does not split into exactly 4 fields
		 * @throws NumberFormatException    if any field is not a valid number
		 */
		public static Rating parseRating(String str) {
			String[] fields = str.split("::");
			if (fields.length != 4) {
				throw new IllegalArgumentException("每行必须是 4 fields");
			}
			int userId = Integer.parseInt(fields[0]);
			int movieId = Integer.parseInt(fields[1]);
			float rating = Float.parseFloat(fields[2]);
			long timestamp = Long.parseLong(fields[3]);
			return new Rating(userId, movieId, rating, timestamp);
		}
	}

	/**
	 * Runs the demo end to end: load ratings, train, evaluate, and print recommendations.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		SparkSession spark = SparkUtils.buildSparkSession();
		try {
			JavaRDD<Rating> ratingsRDD = spark.read()
					.textFile(SparkUtils.getFilePath("data/mllib/als/sample_movielens_ratings.txt"))
					.javaRDD()
					.map(Rating::parseRating);
			ratingsRDD.cache();
			Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
			Dataset<Row>[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 });
			Dataset<Row> training = splits[0];
			Dataset<Row> test = splits[1];

			// Build the ALS recommendation model from the training split.
			ALS als = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("movieId").setRatingCol("rating");
			ALSModel model = als.fit(training);

			// Evaluate the model by computing the RMSE on the test data. The cold start
			// strategy is set to 'drop' so that the evaluation metric does not come out as NaN.
			model.setColdStartStrategy("drop");
			Dataset<Row> predictions = model.transform(test);

			RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction");
			double rmse = evaluator.evaluate(predictions);
			System.out.println("Root-mean-square error = " + rmse);

			// Top TOP_N movie recommendations for every user.
			Dataset<Row> userRecs = model.recommendForAllUsers(TOP_N);
			userRecs.cache(); // reused by both count() and the row dump below
			System.out.println(userRecs.count());
			logRows(userRecs);
			log.warn("以上是:为每个用户生成前{}个电影推荐", TOP_N);

			// Top TOP_N user recommendations for every movie.
			Dataset<Row> movieRecs = model.recommendForAllItems(TOP_N);
			movieRecs.cache(); // reused by both count() and the row dump below
			long count = movieRecs.count();
			System.out.println(count);
			logRows(movieRecs);
			log.warn("以上是:为每部电影生成前{}个用户推荐", TOP_N);

			// Top SUBSET_TOP_N movie recommendations for a specified set of users.
			Dataset<Row> users = ratings.select(als.getUserCol()).distinct().limit(3);
			Dataset<Row> userSubsetRecs = model.recommendForUserSubset(users, SUBSET_TOP_N);
			// Datasets are lazy: without an action the subset query would never execute.
			userSubsetRecs.show();

			// Top SUBSET_TOP_N user recommendations for a specified set of movies.
			Dataset<Row> movies = ratings.select(als.getItemCol()).distinct().limit(3);
			Dataset<Row> movieSubSetRecs = model.recommendForItemSubset(movies, SUBSET_TOP_N);
			movieSubSetRecs.show();
		} finally {
			// Release the session even if a step above throws.
			spark.stop();
		}
	}

	/** Collects every row of {@code recs} to the driver and logs it at WARN level. */
	private static void logRows(Dataset<Row> recs) {
		for (Row row : recs.collectAsList()) {
			log.warn(row.toString());
		}
	}
}

数据

数据采用 Spark 官方发行包中 data/mllib/als 目录下的 sample_movielens_ratings.txt 文件。

结果

[WARN][0,WrappedArray([28,4.2635937], [92,3.9503636], [76,3.7858698], [39,3.3879426], [2,2.999741])]
[WARN][10,WrappedArray([2,3.4163194], [53,3.241712], [25,3.0074506], [42,2.8061774], [87,2.728731])]
[WARN][20,WrappedArray([94,4.0460143], [22,3.6154926], [77,3.612016], [46,3.5636027], [88,3.4248996])]
[WARN][1,WrappedArray([8,5.209615], [55,4.414016], [39,4.104884], [68,3.8478084], [83,3.808718])]
[WARN][11,WrappedArray([23,5.2716155], [30,4.875259], [79,4.8541327], [46,4.577077], [66,4.0911093])]
[WARN][21,WrappedArray([53,5.354313], [2,4.206762], [74,3.9279747], [87,3.827361], [4,2.956087])]
[WARN][2,WrappedArray([39,5.098097], [93,5.0468936], [83,5.011127], [8,4.8800406], [63,4.520872])]
[WARN][12,WrappedArray([43,6.6539083], [46,5.720525], [35,5.1227217], [64,4.977849], [27,4.922715])]
[WARN][22,WrappedArray([30,5.1885133], [51,5.1620464], [75,4.8899145], [22,4.611838], [23,4.3775067])]
[WARN][3,WrappedArray([51,4.9186983], [77,4.49184], [80,4.014811], [18,3.9958167], [88,3.9081008])]
[WARN][13,WrappedArray([93,3.5169842], [41,3.1136718], [70,3.0926945], [92,2.9845672], [83,2.8850255])]
[WARN][23,WrappedArray([55,5.4747477], [62,5.2821107], [32,5.1058264], [49,4.683562], [48,4.589707])]
[WARN][4,WrappedArray([29,3.9669492], [83,3.8033702], [93,3.7849288], [52,3.7831538], [41,3.7319896])]
[WARN][14,WrappedArray([46,6.290769], [8,6.004507], [92,5.8273807], [52,4.9534802], [76,4.79875])]
[WARN][24,WrappedArray([30,5.1304946], [98,5.1174364], [90,5.0042753], [96,4.570305], [53,4.503151])]
[WARN][5,WrappedArray([62,4.6407185], [55,4.5774517], [32,4.542769], [49,4.09161], [68,3.8641832])]
[WARN][15,WrappedArray([46,4.871684], [90,4.0514927], [76,3.8044252], [92,3.6267335], [1,3.590583])]
[WARN][25,WrappedArray([25,4.522156], [89,4.2942066], [28,4.2211885], [39,4.097394], [44,3.761376])]
[WARN][6,WrappedArray([25,4.8586254], [58,3.9096472], [43,3.267941], [47,3.1259787], [93,3.1119614])]
[WARN][16,WrappedArray([90,4.995305], [85,4.9076405], [51,4.708634], [76,4.307944], [39,3.9640565])]
[WARN][26,WrappedArray([51,5.7545366], [22,5.566987], [94,5.0986133], [30,5.0660195], [46,4.901876])]
[WARN][7,WrappedArray([25,5.101813], [47,3.8634279], [29,3.861642], [85,3.8062382], [58,3.6149733])]
[WARN][17,WrappedArray([7,5.6873364], [77,5.4233136], [46,5.0695214], [17,4.9160004], [90,4.6762195])]
[WARN][27,WrappedArray([68,4.560275], [30,4.4277897], [49,4.311203], [88,3.8314245], [89,3.653545])]
[WARN][8,WrappedArray([29,5.076687], [53,5.0071025], [52,4.8734946], [41,4.29465], [70,4.116517])]
[WARN][18,WrappedArray([28,4.99047], [39,4.778125], [89,4.430275], [33,4.2363925], [44,3.5871475])]
[WARN][28,WrappedArray([48,5.7651014], [55,5.6089025], [18,5.2874613], [91,5.203947], [92,5.1611648])]
[WARN][9,WrappedArray([27,5.7635093], [49,4.927815], [7,4.890088], [85,4.3361597], [17,4.1737056])]
[WARN][19,WrappedArray([90,4.075626], [94,3.6695158], [98,3.4434488], [46,3.3110263], [32,3.1066573])]
[WARN][29,WrappedArray([30,5.3171983], [90,5.259403], [46,4.761197], [63,4.126344], [32,4.1161675])]
[WARN]以上是:为每个用户生成前5个电影推荐

pom

<dependency>
	<groupId>org.apache.spark</groupId>
	<artifactId>spark-mllib_2.11</artifactId>
	<version>2.4.0</version>
	<scope>runtime</scope>
</dependency>

你可能感兴趣的:(Spark)