Just started learning some machine learning; this is a note for future reference at work.
package com.tanhua.spark.mongo;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;
public class MLlibRecommend {
public MatrixFactorizationModel bestModel(JavaPairRDD<Long, Rating> ratings){
//Count the users, the publishes, and the user-on-publish ratings
Long numRatings = ratings.count();
Long numUsers = ratings.map(v1 -> (v1._2()).user()).distinct().count();
Long numMovies = ratings.map(v1 -> (v1._2()).product()).distinct().count();
System.out.println("用户:" + numUsers + "动态:" + numMovies + "评论:" + numRatings);
//将样本评分表以key值切分成3个部分,分别用于训练 (60%,并加入用户评分), 校验 (20%), and 测试 (20%)
//该数据在计算过程中要多次应用到,所以cache到内存
Integer numPartitions = 4; // 分区数
// 训练集
JavaRDD<Rating> training = ratings
.filter(v -> v._1() < 6)
.values()
.repartition(numPartitions)
.cache();
// validation set
JavaRDD<Rating> validation = ratings
.filter(v -> v._1() >= 6 && v._1() < 8)
.values()
.repartition(numPartitions).cache();
// test set
JavaRDD<Rating> test = ratings
.filter(v -> v._1() >= 8)
.values()
.cache();
Long numTraining = training.count();
Long numValidation = validation.count();
Long numTest = test.count();
System.out.println("训练集:" + numTraining + " 校验集:" + numValidation + " 测试集:" + numTest);
//训练不同参数下的模型,并在校验集中验证,获取最佳参数下的模
int[] ranks = new int[]{10, 11, 12};
// double[] lambdas = new double[]{0.01, 0.03, 0.1, 0.3, 1, 3};
double[] lambdas = new double[]{0.01};
// int[] numIters = new int[]{8, 9, 10, 11, 12, 13, 14, 15};
int[] numIters = new int[]{8, 9, 10};
MatrixFactorizationModel bestModel = null;
double bestValidationRmse = Double.MAX_VALUE;
int bestRank = 0;
double bestLambda = -0.01;
int bestNumIter = 0;
for (int rank : ranks) {
for (int numIter : numIters) {
for (double lambda : lambdas) {
MatrixFactorizationModel model = ALS.train(training.rdd(), rank, numIter, lambda);
Double validationRmse = computeRmse(model, validation, numValidation);
System.out.println("RMSE(校验集) = " + validationRmse + ", rank = " + rank + ", lambda = " + lambda + ", numIter = " + numIter);
if (validationRmse < bestValidationRmse) {
bestModel = model;
bestValidationRmse = validationRmse;
bestRank = rank;
bestLambda = lambda;
bestNumIter = numIter;
}
}
}
}
double testRmse = computeRmse(bestModel, test, numTest);
System.out.println("测试数据集在 最佳训练模型 rank = " + bestRank + ", lambda = " + bestLambda + ", numIter = " + bestNumIter + ", RMSE = " + testRmse);
// Mean rating over training + validation, used as a naive baseline prediction
Double meanRating = training.union(validation).mapToDouble(v -> v.rating()).mean();
// RMSE of the naive baseline on the test set: sum the squared errors, divide by the count, take the square root
Double baselineRmse = Math.sqrt(test.map(v -> (meanRating - v.rating()) * (meanRating - v.rating())).reduce((v1, v2) -> v1 + v2) / numTest);
// How much the best model improves on the baseline
double improvement = (baselineRmse - testRmse) / baselineRmse * 100;
System.out.println("The best model improves on the baseline by " + String.format("%.2f", improvement) + "%.");
// Retrain the best model on the full ratings data before returning it
bestModel = ALS.train(ratings.values().rdd(), bestRank, bestNumIter, bestLambda);
return bestModel;
}
/**
 * Root mean squared error between the model's predictions and the actual ratings
 **/
public Double computeRmse(MatrixFactorizationModel model, JavaRDD<Rating> data, Long n) {
// Predict a rating for every (user, product) pair in the data
JavaRDD<Rating> predictions = model.predict(data.mapToPair(v -> new Tuple2<>(v.user(), v.product())));
JavaRDD<Tuple2<Double, Double>> predictionsAndRatings = predictions
.mapToPair(v -> new Tuple2<>(new Tuple2<>(v.user(), v.product()), v.rating()))
.join(data.mapToPair(v -> new Tuple2<>(new Tuple2<>(v.user(), v.product()), v.rating()))).values();
Double mse = predictionsAndRatings.map(v -> (v._1 - v._2) * (v._1 - v._2))
.reduce((v1, v2) -> v1 + v2) / n;
//positive square root of the mean squared error
return Math.sqrt(mse);
}
}
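A minimal local smoke test for bestModel, assuming a synthetic ratings set (all ids and scores below are made up for illustration):
package com.tanhua.spark.mongo;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class MLlibRecommendTest {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("MLlibRecommendTest").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // Fake ratings: 50 users x 20 products with random scores;
        // the key plays the role of date % 10, so random 0-9 exercises all three splits
        Random random = new Random(42);
        List<Tuple2<Long, Rating>> data = new ArrayList<>();
        for (int user = 1; user <= 50; user++) {
            for (int product = 1; product <= 20; product++) {
                data.add(new Tuple2<>((long) random.nextInt(10),
                        new Rating(user, product, 1 + random.nextInt(10))));
            }
        }
        JavaPairRDD<Long, Rating> ratings = jsc.parallelizePairs(data);
        MatrixFactorizationModel model = new MLlibRecommend().bestModel(ratings);
        // Recommend 5 products for user 1
        for (Rating r : model.recommendProducts(1, 5)) {
            System.out.println(r.product() + " -> " + r.rating());
        }
        jsc.close();
    }
}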
Build the input data from your own business events. For example: every like, post, or "love" action sends a message to RocketMQ; the recommendation system consumes those messages, turns each action into a score (adding or subtracting points per action type), and saves the result to MongoDB. Put simply, RocketMQ carries the raw event-tracking data, while MongoDB holds the data shaped for model building.
In my actual work the task is: given a ticket submitted by a user, recommend solution articles. The current idea is to match the ticket's title and description against the solution titles, and turn the text-match score into a rating for each solution. For example, if user 1's ticket 1 matches solution 1 with a similarity of 80, give that solution a rating of 8, and build the model from those ratings.
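A rough sketch of that ticket-to-rating idea. TicketMatch and all of its fields are hypothetical names invented here, and how the similarity is computed is out of scope (it could be TF-IDF cosine similarity, a search engine's match score, etc.); the point is that a text-match score in [0, 100] becomes a 0-10 rating, so the bestModel() pipeline above can be reused unchanged.
// Hypothetical sketch: TicketMatch and its fields are invented for illustration.
public class TicketMatch implements java.io.Serializable {
    public final int userId;        // user who submitted the ticket
    public final int solutionId;    // candidate solution article
    public final double similarity; // text-match score, 0-100
    public final long date;         // event timestamp, reused for the date % 10 split key

    public TicketMatch(int userId, int solutionId, double similarity, long date) {
        this.userId = userId;
        this.solutionId = solutionId;
        this.similarity = similarity;
        this.date = date;
    }
}

// In the Spark job (jsc is a JavaSparkContext, matchList a List<TicketMatch>):
JavaPairRDD<Long, Rating> ratings = jsc.parallelize(matchList).mapToPair(m -> {
    double score = m.similarity / 10.0; // e.g. similarity 80 -> rating 8.0
    return new Tuple2<>(m.date % 10, new Rating(m.userId, m.solutionId, score));
});
MatrixFactorizationModel bestModel = new MLlibRecommend().bestModel(ratings);
Sending messages: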
/**
 * Send a quanzi action message
 *
 * @param type 1 = post, 2 = view, 3 = like, 4 = love, 5 = comment, 6 = unlike, 7 = un-love
 * @param publishId id of the publish being acted on
 */
private void sendMsg(Integer type, String publishId) {
try {
User user = UserThreadLocal.get();
Publish publish = this.quanZiApi.queryPublishById(publishId);
Map<String, Object> msg = new HashMap<>();
msg.put("type", type);
msg.put("publishId", publishId);
msg.put("date", System.currentTimeMillis());
msg.put("userId", user.getId());
msg.put("pid", publish.getPid());
this.rocketMQTemplate.convertAndSend("tanhua-quanzi", msg);
} catch (Exception e) {
LOGGER.error("圈子消息发送失效! type = " + type + ", publishId = " + publishId, e);
}
}
/**
 * Send a small-video action message
 *
 * @param videoId id of the video being acted on
 * @param type 1 = post, 2 = like, 3 = unlike, 4 = comment
 * @return true if the message was sent successfully
 */
private Boolean sendMsg(String videoId, Integer type) {
try {
User user = UserThreadLocal.get();
Video video = this.videoApi.queryVideoById(videoId);
//build the message payload
Map<String, Object> msg = new HashMap<>();
msg.put("userId", user.getId());
msg.put("date", System.currentTimeMillis());
msg.put("videoId", videoId);
msg.put("vid", video.getVid());
msg.put("type", type);
this.rocketMQTemplate.convertAndSend("tanhua-video", msg);
} catch (Exception e) {
LOGGER.error("发送消息失败! videoId = " + videoId + ", type = " + type, e);
return false;
}
return true;
}
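With the Jackson-based message converter in rocketmq-spring, the Map payload above lands on the topic as a JSON string, roughly like this (all field values below are made up):
{"type":3,"publishId":"5fb2a6c3e1d2b34a8c9f1e20","date":1606193456789,"userId":1,"pid":1001}
which is exactly the shape the consumers below parse with Jackson's readTree.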
Receiving messages:
package com.tanhua.recommend.msg;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tanhua.dubbo.server.pojo.Publish;
import com.tanhua.recommend.pojo.RecommendQuanZi;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.spring.annotation.RocketMQMessageListener;
import org.apache.rocketmq.spring.core.RocketMQListener;
import org.bson.types.ObjectId;
import org.joda.time.DateTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.io.IOException;
@Component
@RocketMQMessageListener(topic = "tanhua-quanzi",
consumerGroup = "tanhua-quanzi-consumer")
public class QuanZiMsgConsumer implements RocketMQListener<String> {
private static final Logger LOGGER = LoggerFactory.getLogger(QuanZiMsgConsumer.class);
private static final ObjectMapper MAPPER = new ObjectMapper();
@Autowired
private MongoTemplate mongoTemplate;
@Override
public void onMessage(String msg) {
try {
JsonNode jsonNode = MAPPER.readTree(msg);
int type = jsonNode.get("type").asInt();
String publishId = jsonNode.get("publishId").asText();
Long date = jsonNode.get("date").asLong();
Long userId = jsonNode.get("userId").asLong();
Long pid = jsonNode.get("pid").asLong();
RecommendQuanZi recommendQuanZi = new RecommendQuanZi();
recommendQuanZi.setPublishId(pid);
recommendQuanZi.setDate(date);
recommendQuanZi.setId(ObjectId.get());
recommendQuanZi.setUserId(userId);
//1 = post, 2 = view, 3 = like, 4 = love, 5 = comment, 6 = unlike, 7 = un-love
switch (type) {
case 1: {
int score = 0;
Publish publish = this.mongoTemplate.findById(new ObjectId(publishId), Publish.class);
// Score the post by its text length: 1-50 chars = 1 point, 51-100 = 2, over 100 = 3
int length = StringUtils.length(publish.getText());
if(length > 0 && length <= 50){
score = 1;
}else if(length > 50 && length <= 100){
score = 2;
}else if(length > 100){
score = 3;
}
// One extra point per attached media file
if(!CollectionUtils.isEmpty(publish.getMedias())){
score += publish.getMedias().size();
}
recommendQuanZi.setScore(Double.valueOf(score));
break;
}
case 2: {
recommendQuanZi.setScore(1d);
break;
}
case 3: {
recommendQuanZi.setScore(5d);
break;
}
case 4: {
recommendQuanZi.setScore(8d);
break;
}
case 5: {
recommendQuanZi.setScore(10d);
break;
}
case 6: {
recommendQuanZi.setScore(-5d);
break;
}
case 7: {
recommendQuanZi.setScore(-8d);
break;
}
default: {
recommendQuanZi.setScore(0d);
break;
}
}
// Write the record to a per-day MongoDB collection
String collectName = "recommend_quanzi_" + new DateTime().toString("yyyyMMdd");
this.mongoTemplate.save(recommendQuanZi, collectName);
} catch (Exception e) {
LOGGER.error("消息处理失败! msg = " + msg);
}
}
}
package com.tanhua.recommend.msg;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tanhua.recommend.pojo.RecommendVideo;
import org.apache.rocketmq.spring.annotation.RocketMQMessageListener;
import org.apache.rocketmq.spring.core.RocketMQListener;
import org.bson.types.ObjectId;
import org.joda.time.DateTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.stereotype.Component;
@Component
@RocketMQMessageListener(topic = "tanhua-video",
consumerGroup = "tanhua-video-consumer")
public class VideoMsgConsumer implements RocketMQListener<String> {
private static final ObjectMapper MAPPER = new ObjectMapper();
private static final Logger LOGGER = LoggerFactory.getLogger(VideoMsgConsumer.class);
@Autowired
private MongoTemplate mongoTemplate;
@Override
public void onMessage(String msg) {
try {
JsonNode jsonNode = MAPPER.readTree(msg);
Long userId = jsonNode.get("userId").asLong();
Long vid = jsonNode.get("vid").asLong();
Integer type = jsonNode.get("type").asInt();
//1 = post, 2 = like, 3 = unlike, 4 = comment
RecommendVideo recommendVideo = new RecommendVideo();
recommendVideo.setUserId(userId);
recommendVideo.setId(ObjectId.get());
recommendVideo.setDate(System.currentTimeMillis());
recommendVideo.setVideoId(vid);
switch (type) {
case 1: {
recommendVideo.setScore(2d);
break;
}
case 2: {
recommendVideo.setScore(5d);
break;
}
case 3: {
recommendVideo.setScore(-5d);
break;
}
case 4: {
recommendVideo.setScore(10d);
break;
}
default: {
recommendVideo.setScore(0d);
break;
}
}
String collectionName = "recommend_video_" + new DateTime().toString("yyyyMMdd");
this.mongoTemplate.save(recommendVideo, collectionName);
} catch (Exception e) {
LOGGER.error("处理小视频消息失败~" + msg, e);
}
}
}
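The RecommendQuanZi and RecommendVideo pojos used above are not shown in this note. A minimal sketch of RecommendQuanZi, assuming plain fields that match the setters the consumer calls (RecommendVideo has the same shape with videoId in place of publishId):
package com.tanhua.recommend.pojo;

import org.bson.types.ObjectId;

public class RecommendQuanZi implements java.io.Serializable {
    private ObjectId id;    // MongoDB _id
    private Long userId;    // who performed the action
    private Long publishId; // which publish (pid) the action targets
    private Double score;   // score derived from the action type
    private Long date;      // action timestamp in milliseconds

    public ObjectId getId() { return id; }
    public void setId(ObjectId id) { this.id = id; }
    public Long getUserId() { return userId; }
    public void setUserId(Long userId) { this.userId = userId; }
    public Long getPublishId() { return publishId; }
    public void setPublishId(Long publishId) { this.publishId = publishId; }
    public Double getScore() { return score; }
    public void setScore(Double score) { this.score = score; }
    public Long getDate() { return date; }
    public void setDate(Long date) { this.date = date; }
}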
1. Quanzi ("circles") is like a moments feed for strangers: users click, view, comment, like, post publishes, and so on.
package com.tanhua.spark.mongo;
import com.mongodb.spark.MongoSpark;
import com.mongodb.spark.rdd.api.java.JavaMongoRDD;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.bson.Document;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.JedisCluster;
import scala.Tuple2;
import java.io.InputStream;
import java.util.*;
public class SparkQunaZi {
public static void main(String[] args) throws Exception {
//Load the external configuration file app.properties
InputStream inputStream = SparkQunaZi.class.getClassLoader().getResourceAsStream("app.properties");
Properties properties = new Properties();
properties.load(inputStream);
//Build the Spark configuration
SparkConf sparkConf = new SparkConf()
.setAppName("SparkQunaZi")
.setMaster("local[*]")
.set("spark.mongodb.input.uri", properties.getProperty("spark.mongodb.input.uri"));
//Build the Spark context
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
//Load the data from MongoDB
JavaMongoRDD<Document> rdd = MongoSpark.load(jsc);
//print the raw data for debugging
// rdd.foreach(document -> System.out.println(document.toJson()));
//The same user may act on the same publish multiple times; merge those records by summing the scores
JavaRDD<Document> values = rdd.mapToPair(document -> {
Long userId = document.getLong("userId");
Long publishId = document.getLong("publishId");
return new Tuple2<>(userId + "_" + publishId, document);
}).reduceByKey((v1, v2) -> {
double newScore = v1.getDouble("score") + v2.getDouble("score");
v1.put("score", newScore);
return v1;
}).values();
//distinct user id list
List<Long> userIdList = rdd.map(v1 -> v1.getLong("userId")).distinct().collect();
//print the merged data for debugging
// values.foreach(document -> System.out.println(document.toJson()));
//Key by date % 10 with a Rating as the value, the shape bestModel() expects
JavaPairRDD<Long, Rating> ratings = values.mapToPair(document -> {
Long date = document.getLong("date");
int userId = document.getLong("userId").intValue();
int publishId = document.getLong("publishId").intValue();
Double score = document.getDouble("score");
Rating rating = new Rating(userId, publishId, score);
return new Tuple2<>(date % 10, rating);
});
MLlibRecommend mLlibRecommend = new MLlibRecommend();
MatrixFactorizationModel bestModel = mLlibRecommend.bestModel(ratings);
//Connect to the Redis cluster to store the results
String redisNodesStr = properties.getProperty("redis.cluster.nodes");
String[] redisNodesStrs = StringUtils.split(redisNodesStr, ',');
Set<HostAndPort> nodes = new HashSet<>();
for (String nodesStr : redisNodesStrs) {
String[] ss = StringUtils.split(nodesStr, ':');
nodes.add(new HostAndPort(ss[0], Integer.valueOf(ss[1])));
}
JedisCluster jedisCluster = new JedisCluster(nodes);
for (Long userId : userIdList) {
Rating[] recommendProducts = bestModel.recommendProducts(userId.intValue(), 20);
List<Integer> products = new ArrayList<>();
for (Rating product : recommendProducts) {
products.add(product.product());
}
String key = "QUANZI_PUBLISH_RECOMMEND_" + userId;
jedisCluster.set(key, StringUtils.join(products, ','));
}
//close resources
jedisCluster.close();
jsc.close();
}
}
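Both Spark jobs load an app.properties from the classpath. A sample with placeholder hosts; the property keys match what the code reads, and the Mongo URIs should point at whichever collection the consumers wrote (e.g. a specific recommend_quanzi_yyyyMMdd day) in the mongodb://host/db.collection format the MongoDB Spark connector expects:
# MongoDB source for the quanzi job (host, database, and collection are placeholders)
spark.mongodb.input.uri=mongodb://192.168.1.100:27017/tanhua.recommend_quanzi_20200918
# MongoDB source for the video job
spark.video.mongodb.input.uri=mongodb://192.168.1.100:27017/tanhua.recommend_video_20200918
# Redis cluster nodes, comma-separated host:port pairs
redis.cluster.nodes=192.168.1.101:6379,192.168.1.102:6379,192.168.1.103:6379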
2. Video recommendation works the same way: likes, shares, views, comments, and so on add or subtract points.
package com.tanhua.spark.mongo;
import com.mongodb.spark.MongoSpark;
import com.mongodb.spark.rdd.api.java.JavaMongoRDD;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.bson.Document;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.JedisCluster;
import scala.Tuple2;
import java.io.InputStream;
import java.util.*;
public class SparkVideo {
public static void main(String[] args) throws Exception{
//Load the external configuration file
InputStream inputStream = SparkVideo.class.getClassLoader().getResourceAsStream("app.properties");
Properties properties = new Properties();
properties.load(inputStream);
//Build the Spark configuration
SparkConf sparkConf = new SparkConf()
.setAppName("SparkVideo")
.setMaster("local[*]")
.set("spark.mongodb.input.uri", properties.getProperty("spark.video.mongodb.input.uri"));
//Build the Spark context
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
//Load the data from MongoDB
JavaMongoRDD<Document> rdd = MongoSpark.load(jsc);
// rdd.foreach(document -> System.out.println(document.toJson()));
//The same user may rate the same video multiple times; merge those records by summing the scores
JavaRDD<Document> values = rdd.mapToPair(document -> {
Integer user = document.getLong("userId").intValue();
Integer product = document.getLong("videoId").intValue();
return new Tuple2<>(user + "_" + product, document);
}).reduceByKey((v1, v2) -> {
Double score = v1.getDouble("score") + v2.getDouble("score");
v1.put("score", score);
return v1;
}).values();
//distinct user id list
List<Long> userIdList = rdd.map(v1 -> v1.getLong("userId")).distinct().collect();
// values.foreach(document -> System.out.println(document.toJson()));
//Key by date % 10, with a Rating as the value, for the later train/validation/test split
JavaPairRDD<Long, Rating> ratings = values.mapToPair(document -> {
Integer user = document.getLong("userId").intValue();
Integer product = document.getLong("videoId").intValue();
Double score = document.getDouble("score");
Long date = document.getLong("date");
Rating rating = new Rating(user, product, score);
return new Tuple2<>(date % 10, rating);
});
//Tune with MLlib and keep the best recommendation model
MLlibRecommend mLlibRecommend = new MLlibRecommend();
MatrixFactorizationModel bestModel = mLlibRecommend.bestModel(ratings);
//Build the Redis cluster client
String redisClusterNodes = properties.getProperty("redis.cluster.nodes");
String[] redisNodes = redisClusterNodes.split(",");
Set<HostAndPort> nodes = new HashSet<>();
for (String redisNode : redisNodes) {
String[] hostAndPorts = redisNode.split(":");
nodes.add(new HostAndPort(hostAndPorts[0], Integer.valueOf(hostAndPorts[1])));
}
JedisCluster jedisCluster = new JedisCluster(nodes);
//Recommend the top 20 videos for each user
for (Long userId : userIdList) {
Rating[] recommendProducts = bestModel.recommendProducts(userId.intValue(), 20);
List<Integer> products = new ArrayList<>();
for (Rating rating : recommendProducts) {
products.add(rating.product());
}
//store in redis
String key = "QUANZI_VIDEO_RECOMMEND_" + userId;
jedisCluster.set(key, StringUtils.join(products, ','));
}
//close connections
jedisCluster.close();
jsc.close();
}
}
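On the serving side, the lists can be read straight back out of Redis. A minimal sketch, assuming the same key prefix the jobs write and a placeholder cluster node (imports match the Spark jobs above):
// Minimal sketch: read a user's precomputed video recommendations back from Redis.
public static List<Long> loadVideoRecommendations(Long userId) {
    Set<HostAndPort> nodes = new HashSet<>();
    nodes.add(new HostAndPort("192.168.1.101", 6379)); // placeholder cluster node
    JedisCluster jedisCluster = new JedisCluster(nodes);
    String value = jedisCluster.get("QUANZI_VIDEO_RECOMMEND_" + userId);
    List<Long> videoIds = new ArrayList<>();
    if (StringUtils.isNotEmpty(value)) {
        // the job stored the top-20 product ids joined with ','
        for (String id : StringUtils.split(value, ',')) {
            videoIds.add(Long.valueOf(id));
        }
    }
    jedisCluster.close();
    return videoIds; // ordered best-first, as produced by recommendProducts
}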