sparkmllib 推荐系统实现(学习)

sparkmllib 推荐系统实现

  • sparkmllib 推荐系统实现
    • 一.构建训练模型
    • 二.构建数据逻辑(圈子/视频)
    • 三.推荐逻辑实现(圈子/视频)

sparkmllib 推荐系统实现

刚刚学习机器学习相关的知识,做一个笔记,方便以后工作中使用。

一.构建训练模型

package com.tanhua.spark.mongo;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;

public class MLlibRecommend {

    /** Number of partitions the cached training and validation sets are spread over. */
    private static final int NUM_PARTITIONS = 4;

    /**
     * Finds the best ALS model by grid search and retrains it on the complete data set.
     *
     * <p>The key of every entry in {@code ratings} decides which split the rating lands in:
     * keys {@code < 6} form the training set, keys in {@code [6, 8)} the validation set and
     * keys {@code >= 8} the test set (with keys in {@code 0..9} this is roughly 60/20/20).
     * Each (rank, iterations, lambda) combination is trained on the training set and scored by
     * RMSE on the validation set. The winner is reported against the test set and against a
     * baseline that always predicts the mean rating, then retrained on all ratings.
     *
     * @param ratings ratings keyed by their split bucket
     * @return an ALS model trained on all ratings with the best parameters found
     * @throws IllegalStateException if the training or validation split is empty, or no
     *     candidate model produced a finite validation RMSE
     */
    public MatrixFactorizationModel bestModel(JavaPairRDD<Long, Rating> ratings) {
        // Overall statistics: number of ratings, distinct users and distinct items.
        long numRatings = ratings.count();
        long numUsers = ratings.map(v -> v._2().user()).distinct().count();
        long numMovies = ratings.map(v -> v._2().product()).distinct().count();
        System.out.println("用户:" + numUsers + "动态:" + numMovies + "评论:" + numRatings);

        // The splits are read repeatedly during the grid search, so they are cached.
        JavaRDD<Rating> training = ratings
                .filter(v -> v._1() < 6)
                .values()
                .repartition(NUM_PARTITIONS)
                .cache();
        JavaRDD<Rating> validation = ratings
                .filter(v -> v._1() >= 6 && v._1() < 8)
                .values()
                .repartition(NUM_PARTITIONS)
                .cache();
        JavaRDD<Rating> test = ratings
                .filter(v -> v._1() >= 8)
                .values()
                .cache();

        try {
            long numTraining = training.count();
            long numValidation = validation.count();
            long numTest = test.count();
            System.out.println("训练集:" + numTraining + " 校验集:" + numValidation + " 测试集:" + numTest);

            if (numTraining == 0 || numValidation == 0) {
                throw new IllegalStateException("Cannot select a model: training set has "
                        + numTraining + " ratings, validation set has " + numValidation);
            }

            // Candidate parameters. A wider grid (lambda 0.01..3, iterations 8..15) may find a
            // better model but multiplies the number of ALS runs.
            int[] ranks = {10, 11, 12};
            double[] lambdas = {0.01};
            int[] numIters = {8, 9, 10};

            MatrixFactorizationModel bestModel = null;
            double bestValidationRmse = Double.MAX_VALUE;
            int bestRank = 0;
            double bestLambda = -0.01;
            int bestNumIter = 0;

            for (int rank : ranks) {
                for (int numIter : numIters) {
                    for (double lambda : lambdas) {
                        MatrixFactorizationModel model = ALS.train(training.rdd(), rank, numIter, lambda);
                        double validationRmse = computeRmse(model, validation, numValidation);
                        System.out.println("RMSE(校验集) = " + validationRmse + ", rank = " + rank + ", lambda = " + lambda + ", numIter = " + numIter);

                        if (validationRmse < bestValidationRmse) {
                            bestModel = model;
                            bestValidationRmse = validationRmse;
                            bestRank = rank;
                            bestLambda = lambda;
                            bestNumIter = numIter;
                        }
                    }
                }
            }

            if (bestModel == null) {
                // Only reachable when every validation RMSE was NaN.
                throw new IllegalStateException("No candidate model produced a finite validation RMSE");
            }

            if (numTest > 0) {
                double testRmse = computeRmse(bestModel, test, numTest);
                System.out.println("测试数据集在 最佳训练模型 rank = " + bestRank + ", lambda = " + bestLambda + ", numIter = " + bestNumIter + ", RMSE = " + testRmse);

                // Baseline: always predict the mean rating of training + validation data.
                double meanRating = training.union(validation).mapToDouble(Rating::rating).mean();
                double baselineSquaredError = test
                        .mapToDouble(v -> (meanRating - v.rating()) * (meanRating - v.rating()))
                        .sum();
                double baselineRmse = Math.sqrt(baselineSquaredError / numTest);

                if (baselineRmse > 0) {
                    // Relative improvement of the model over the baseline, in percent.
                    double improvement = (baselineRmse - testRmse) / baselineRmse * 100;
                    System.out.println("最佳训练模型的准确率提升了:" + String.format("%.2f", improvement) + "%.");
                }
            } else {
                System.out.println("测试集为空, 跳过测试集评估");
            }

            // Retrain the winning configuration on every rating, not just the training split.
            return ALS.train(ratings.values().rdd(), bestRank, bestNumIter, bestLambda);
        } finally {
            training.unpersist();
            validation.unpersist();
            test.unpersist();
        }
    }

    /**
     * Computes the root-mean-square error of {@code model} against the ratings in {@code data}.
     *
     * <p>Predictions are joined with the actual ratings on (user, product). A pair for which
     * the model returns no prediction contributes nothing to the squared-error sum but is
     * still counted in {@code n}. MLlib drops pairs whose user or item is unknown to the
     * model (per MLlib behaviour; not visible here).
     *
     * @param model model to evaluate
     * @param data  ratings holding the ground truth
     * @param n     denominator, normally {@code data.count()}
     * @return {@code sqrt(sum((predicted - actual)^2) / n)}
     * @throws IllegalArgumentException if {@code n} is null or not positive
     */
    public Double computeRmse(MatrixFactorizationModel model, JavaRDD<Rating> data, Long n) {
        if (n == null || n <= 0) {
            throw new IllegalArgumentException("n must be positive, got " + n);
        }

        JavaRDD<Rating> predictions = model.predict(data.mapToPair(v -> new Tuple2<>(v.user(), v.product())));

        // (predicted, actual) for every pair present in both predictions and data.
        JavaRDD<Tuple2<Double, Double>> predictionsAndRatings = predictions
                .mapToPair(v -> new Tuple2<>(new Tuple2<>(v.user(), v.product()), v.rating()))
                .join(data.mapToPair(v -> new Tuple2<>(new Tuple2<>(v.user(), v.product()), v.rating())))
                .values();

        // Sum first, divide once: dividing inside reduce() would rescale partial sums at every
        // merge and yield a meaningless, partition-dependent value.
        double sumSquaredError = predictionsAndRatings
                .mapToDouble(v -> (v._1() - v._2()) * (v._1() - v._2()))
                .sum();
        return Math.sqrt(sumSquaredError / n);
    }
}

二.构建数据逻辑(圈子/视频)

	根据自己的实际业务构建数据,例子如下:
	比如每次点赞、发圈、点击喜欢时,都会发送一条消息到 RocketMQ,
	推荐系统消费这些消息,将数据整理(根据每种操作进行加分或减分)后保存到 MongoDB。
	通俗地讲,RocketMQ 中的是埋点数据,MongoDB 中的是符合模型构建要求的数据。
	实际工作:根据用户提交的工单,为其推荐一些解决方案文章。
	目前的想法是:将工单的标题、描述与解决方案的标题进行文本匹配,再据此给解决方案打分,
	比如用户1的工单1与解决方案1的匹配度为80%,就给该解决方案打8分,用这些数据去构建模型。

sparkmllib 推荐系统实现(学习)_第1张图片

    /**
     * Publishes a circle (quanzi) action event to the RocketMQ topic {@code tanhua-quanzi}.
     *
     * <p>The payload carries the action type, the post id, the current timestamp, the acting
     * user's id and the post's {@code pid}. Any failure is logged and swallowed, so callers are
     * never interrupted by messaging problems.
     *
     * @param type      action code: 1-publish, 2-view, 3-like, 4-love, 5-comment, 6-unlike, 7-unlove
     * @param publishId id of the post the action refers to
     */
    private void sendMsg(Integer type, String publishId) {
        try {
            User actor = UserThreadLocal.get();
            Publish target = this.quanZiApi.queryPublishById(publishId);

            Map<String, Object> payload = new HashMap<>();
            payload.put("type", type);
            payload.put("publishId", publishId);
            payload.put("date", System.currentTimeMillis());
            payload.put("userId", actor.getId());
            payload.put("pid", target.getPid());

            this.rocketMQTemplate.convertAndSend("tanhua-quanzi", payload);
        } catch (Exception sendFailure) {
            LOGGER.error("圈子消息发送失效! type = " + type + ", publishId = " + publishId, sendFailure);
        }
    }
    /**
     * Publishes a short-video action event to the RocketMQ topic {@code tanhua-video}.
     *
     * <p>The payload carries the acting user's id, the current timestamp, the video id, the
     * video's {@code vid} and the action type.
     *
     * @param videoId id of the video the action refers to
     * @param type    action code: 1-publish, 2-like, 3-unlike, 4-comment
     * @return {@code true} if the message was handed to RocketMQ, {@code false} if anything
     *     failed (the failure is logged)
     */
    private Boolean sendMsg(String videoId, Integer type) {
        boolean delivered;
        try {
            User actor = UserThreadLocal.get();
            Video target = this.videoApi.queryVideoById(videoId);

            Map<String, Object> payload = new HashMap<>();
            payload.put("userId", actor.getId());
            payload.put("date", System.currentTimeMillis());
            payload.put("videoId", videoId);
            payload.put("vid", target.getVid());
            payload.put("type", type);

            this.rocketMQTemplate.convertAndSend("tanhua-video", payload);
            delivered = true;
        } catch (Exception sendFailure) {
            LOGGER.error("发送消息失败! videoId = " + videoId + ", type = " + type, sendFailure);
            delivered = false;
        }
        return delivered;
    }
	接收消息
package com.tanhua.recommend.msg;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tanhua.dubbo.server.pojo.Publish;
import com.tanhua.recommend.pojo.RecommendQuanZi;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.spring.annotation.RocketMQMessageListener;
import org.apache.rocketmq.spring.core.RocketMQListener;
import org.bson.types.ObjectId;
import org.joda.time.DateTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.io.IOException;

@Component
@RocketMQMessageListener(topic = "tanhua-quanzi",
        consumerGroup = "tanhua-quanzi-consumer")
public class QuanZiMsgConsumer implements RocketMQListener<String> {

    private static final Logger LOGGER = LoggerFactory.getLogger(QuanZiMsgConsumer.class);

    private static final ObjectMapper MAPPER = new ObjectMapper();

    @Autowired
    private MongoTemplate mongoTemplate;

    /**
     * Turns one circle-action message into a {@link RecommendQuanZi} score record and saves it
     * into the MongoDB collection {@code recommend_quanzi_<yyyyMMdd>} of the current day.
     *
     * <p>Reads the JSON fields {@code type}, {@code publishId}, {@code date}, {@code userId} and
     * {@code pid}. Scores by action type: 1-publish (see {@link #scorePublish}), 2-view 1,
     * 3-like 5, 4-love 8, 5-comment 10, 6-unlike -5, 7-unlove -8, anything else 0.
     * Messages that cannot be processed are logged and dropped.
     *
     * @param msg raw JSON message body
     */
    @Override
    public void onMessage(String msg) {
        try {
            JsonNode jsonNode = MAPPER.readTree(msg);
            int type = jsonNode.get("type").asInt();
            String publishId = jsonNode.get("publishId").asText();
            Long date = jsonNode.get("date").asLong();
            Long userId = jsonNode.get("userId").asLong();
            Long pid = jsonNode.get("pid").asLong();

            RecommendQuanZi recommendQuanZi = new RecommendQuanZi();
            recommendQuanZi.setPublishId(pid);
            recommendQuanZi.setDate(date);
            recommendQuanZi.setId(ObjectId.get());
            recommendQuanZi.setUserId(userId);

            double score;
            switch (type) {
                case 1: {
                    Publish publish = this.mongoTemplate.findById(new ObjectId(publishId), Publish.class);
                    if (publish == null) {
                        // Nothing to score; skip explicitly instead of failing with an NPE.
                        LOGGER.warn("动态不存在, 忽略消息! publishId = " + publishId + ", msg = " + msg);
                        return;
                    }
                    score = scorePublish(publish);
                    break;
                }
                case 2:
                    score = 1d;
                    break;
                case 3:
                    score = 5d;
                    break;
                case 4:
                    score = 8d;
                    break;
                case 5:
                    score = 10d;
                    break;
                case 6:
                    score = -5d;
                    break;
                case 7:
                    score = -8d;
                    break;
                default:
                    score = 0d;
                    break;
            }
            recommendQuanZi.setScore(score);

            // One collection per day, named after the processing date.
            String collectName = "recommend_quanzi_" + new DateTime().toString("yyyyMMdd");
            this.mongoTemplate.save(recommendQuanZi, collectName);

        } catch (Exception e) {
            // Pass the exception so the cause and stack trace are not lost.
            LOGGER.error("消息处理失败! msg = " + msg, e);
        }
    }

    /**
     * Scores a newly published post.
     *
     * <p>Text score: 0 when the post has no text, 1 for up to 50 characters, 2 for up to 100, and
     * 3 above that. One extra point is added per attached media item.
     *
     * @param publish the post, must not be null
     * @return the post's score
     */
    private static double scorePublish(Publish publish) {
        int length = StringUtils.length(publish.getText());

        int score;
        if (length == 0) {
            score = 0; // previously fell into the ">100" branch and got the maximum text score
        } else if (length <= 50) {
            score = 1;
        } else if (length <= 100) {
            score = 2;
        } else {
            score = 3;
        }

        if (!CollectionUtils.isEmpty(publish.getMedias())) {
            score += publish.getMedias().size();
        }
        return score;
    }
}

package com.tanhua.recommend.msg;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tanhua.recommend.pojo.RecommendVideo;
import org.apache.rocketmq.spring.annotation.RocketMQMessageListener;
import org.apache.rocketmq.spring.core.RocketMQListener;
import org.bson.types.ObjectId;
import org.joda.time.DateTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.stereotype.Component;

@Component
@RocketMQMessageListener(topic = "tanhua-video",
        consumerGroup = "tanhua-video-consumer")
public class VideoMsgConsumer implements RocketMQListener<String> {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    private static final Logger LOGGER = LoggerFactory.getLogger(VideoMsgConsumer.class);

    @Autowired
    private MongoTemplate mongoTemplate;

    /**
     * Turns one short-video action message into a {@link RecommendVideo} score record and saves
     * it into the MongoDB collection {@code recommend_video_<yyyyMMdd>} of the current day.
     *
     * <p>Reads the JSON fields {@code userId}, {@code vid}, {@code type} and, if present,
     * {@code date}. Scores by action type: 1-publish 2, 2-like 5, 3-unlike -5, 4-comment 10,
     * anything else 0. Messages that cannot be processed are logged and dropped.
     *
     * @param msg raw JSON message body
     */
    @Override
    public void onMessage(String msg) {
        try {
            JsonNode jsonNode = MAPPER.readTree(msg);

            Long userId = jsonNode.get("userId").asLong();
            Long vid = jsonNode.get("vid").asLong();
            int type = jsonNode.get("type").asInt();

            // Use the event time sent by the producer (as the circle consumer does) rather than
            // the consume time; fall back to "now" for messages without a usable date.
            JsonNode dateNode = jsonNode.get("date");
            long date = (dateNode != null && dateNode.canConvertToLong())
                    ? dateNode.asLong()
                    : System.currentTimeMillis();

            RecommendVideo recommendVideo = new RecommendVideo();
            recommendVideo.setUserId(userId);
            recommendVideo.setId(ObjectId.get());
            recommendVideo.setDate(date);
            recommendVideo.setVideoId(vid);

            // 1-publish, 2-like, 3-unlike, 4-comment
            switch (type) {
                case 1:
                    recommendVideo.setScore(2d);
                    break;
                case 2:
                    recommendVideo.setScore(5d);
                    break;
                case 3:
                    recommendVideo.setScore(-5d);
                    break;
                case 4:
                    recommendVideo.setScore(10d);
                    break;
                default:
                    recommendVideo.setScore(0d);
                    break;
            }

            // One collection per day, named after the processing date.
            String collectionName = "recommend_video_" + new DateTime().toString("yyyyMMdd");
            this.mongoTemplate.save(recommendVideo, collectionName);

        } catch (Exception e) {
            LOGGER.error("处理小视频消息失败~" + msg, e);
        }
    }
}

三.推荐逻辑实现(圈子/视频)

	1.圈子推荐:圈子类似陌生人朋友圈,用户可以浏览、点赞、喜欢、评论、发动态等,根据这些操作进行加分或减分。
package com.tanhua.spark.mongo;

import com.mongodb.spark.MongoSpark;
import com.mongodb.spark.rdd.api.java.JavaMongoRDD;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.bson.Document;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.JedisCluster;
import scala.Tuple2;

import java.io.InputStream;
import java.util.*;

public class SparkQunaZi {

    /** Number of posts recommended (and stored in Redis) per user. */
    private static final int RECOMMEND_SIZE = 20;

    /**
     * Batch job: trains an ALS model on the circle-action scores stored in MongoDB and stores
     * the top recommendations of every user in Redis.
     *
     * <p>Reads {@code spark.mongodb.input.uri} and {@code redis.cluster.nodes} from the
     * classpath resource {@code app.properties}. Each user's recommendations are written to the
     * key {@code QUANZI_PUBLISH_RECOMMEND_<userId>} as a comma-separated list of post ids.
     */
    public static void main(String[] args) throws Exception {
        Properties properties = loadProperties("app.properties");

        SparkConf sparkConf = new SparkConf()
                .setAppName("SparkQunaZi")
                .setMaster("local[*]")
                .set("spark.mongodb.input.uri", properties.getProperty("spark.mongodb.input.uri"));

        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        try {
            JavaMongoRDD<Document> rdd = MongoSpark.load(jsc);

            // A user may act on the same post several times: sum those scores into a single
            // record per (user, post) pair.
            JavaRDD<Document> values = rdd.mapToPair(document -> {
                Long userId = document.getLong("userId");
                Long publishId = document.getLong("publishId");
                return new Tuple2<>(userId + "_" + publishId, document);
            }).reduceByKey((v1, v2) -> {
                double newScore = v1.getDouble("score") + v2.getDouble("score");
                v1.put("score", newScore);
                return v1;
            }).values();

            // Every user present in the data gets a recommendation list.
            List<Long> userIdList = rdd.map(v1 -> v1.getLong("userId")).distinct().collect();

            // Key = date % 10, which MLlibRecommend uses to split train/validation/test.
            // Cached because bestModel() traverses it many times (otherwise Mongo is re-read).
            JavaPairRDD<Long, Rating> ratings = values.mapToPair(document -> {
                Long date = document.getLong("date");
                int userId = document.getLong("userId").intValue();
                int publishId = document.getLong("publishId").intValue();
                Double score = document.getDouble("score");
                Rating rating = new Rating(userId, publishId, score);
                return new Tuple2<>(date % 10, rating);
            }).cache();

            MatrixFactorizationModel bestModel = new MLlibRecommend().bestModel(ratings);

            JedisCluster jedisCluster = new JedisCluster(parseRedisNodes(properties.getProperty("redis.cluster.nodes")));
            try {
                for (Long userId : userIdList) {
                    Rating[] recommendProducts = bestModel.recommendProducts(userId.intValue(), RECOMMEND_SIZE);

                    List<Integer> products = new ArrayList<>(recommendProducts.length);
                    for (Rating product : recommendProducts) {
                        products.add(product.product());
                    }

                    String key = "QUANZI_PUBLISH_RECOMMEND_" + userId;
                    jedisCluster.set(key, StringUtils.join(products, ','));
                }
            } finally {
                jedisCluster.close();
            }
        } finally {
            // Always release the Spark context, even if training or Redis writes fail.
            jsc.close();
        }
    }

    /**
     * Loads a properties file from the classpath, closing the stream afterwards.
     *
     * @throws IllegalStateException if the resource does not exist
     */
    private static Properties loadProperties(String resource) throws java.io.IOException {
        try (InputStream inputStream = SparkQunaZi.class.getClassLoader().getResourceAsStream(resource)) {
            if (inputStream == null) {
                throw new IllegalStateException("Missing classpath resource: " + resource);
            }
            Properties properties = new Properties();
            properties.load(inputStream);
            return properties;
        }
    }

    /**
     * Parses a {@code host:port,host:port,...} list into Redis cluster nodes.
     *
     * @throws IllegalStateException if the value is blank or an entry is not {@code host:port}
     */
    private static Set<HostAndPort> parseRedisNodes(String nodesProperty) {
        if (StringUtils.isBlank(nodesProperty)) {
            throw new IllegalStateException("Property redis.cluster.nodes is not configured");
        }
        Set<HostAndPort> nodes = new HashSet<>();
        for (String node : StringUtils.split(nodesProperty, ',')) {
            String[] hostAndPort = StringUtils.split(node.trim(), ':');
            if (hostAndPort.length != 2) {
                throw new IllegalStateException("Invalid redis node (expected host:port): " + node);
            }
            nodes.add(new HostAndPort(hostAndPort[0], Integer.parseInt(hostAndPort[1])));
        }
        return nodes;
    }
}

	2.视频推荐:同理,也是根据点赞、分享、查看、评论等操作进行加分或减分(当前消费者代码只处理发布、点赞、取消点赞、评论四种操作)。
package com.tanhua.spark.mongo;

import com.mongodb.spark.MongoSpark;
import com.mongodb.spark.rdd.api.java.JavaMongoRDD;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.bson.Document;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.JedisCluster;
import scala.Tuple2;

import java.io.InputStream;
import java.util.*;

public class SparkVideo {

    /** Number of videos recommended (and stored in Redis) per user. */
    private static final int RECOMMEND_SIZE = 20;

    /**
     * Batch job: trains an ALS model on the short-video action scores stored in MongoDB and
     * stores the top recommendations of every user in Redis.
     *
     * <p>Reads {@code spark.video.mongodb.input.uri} and {@code redis.cluster.nodes} from the
     * classpath resource {@code app.properties}. Each user's recommendations are written to the
     * key {@code QUANZI_VIDEO_RECOMMEND_<userId>} as a comma-separated list of video ids.
     */
    public static void main(String[] args) throws Exception {
        Properties properties = loadProperties("app.properties");

        SparkConf sparkConf = new SparkConf()
                .setAppName("SparkVideo")
                .setMaster("local[*]")
                .set("spark.mongodb.input.uri", properties.getProperty("spark.video.mongodb.input.uri"));

        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        try {
            JavaMongoRDD<Document> rdd = MongoSpark.load(jsc);

            // A user may act on the same video several times: sum those scores into a single
            // record per (user, video) pair.
            JavaRDD<Document> values = rdd.mapToPair(document -> {
                Integer user = document.getLong("userId").intValue();
                Integer product = document.getLong("videoId").intValue();
                return new Tuple2<>(user + "_" + product, document);
            }).reduceByKey((v1, v2) -> {
                Double score = v1.getDouble("score") + v2.getDouble("score");
                v1.put("score", score);
                return v1;
            }).values();

            // Every user present in the data gets a recommendation list.
            List<Long> userIdList = rdd.map(v1 -> v1.getLong("userId")).distinct().collect();

            // Key = date % 10, which MLlibRecommend uses to split train/validation/test.
            // Cached because bestModel() traverses it many times (otherwise Mongo is re-read).
            JavaPairRDD<Long, Rating> ratings = values.mapToPair(document -> {
                Integer user = document.getLong("userId").intValue();
                Integer product = document.getLong("videoId").intValue();
                Double score = document.getDouble("score");
                Long date = document.getLong("date");
                Rating rating = new Rating(user, product, score);
                return new Tuple2<>(date % 10, rating);
            }).cache();

            MatrixFactorizationModel bestModel = new MLlibRecommend().bestModel(ratings);

            JedisCluster jedisCluster = new JedisCluster(parseRedisNodes(properties.getProperty("redis.cluster.nodes")));
            try {
                for (Long userId : userIdList) {
                    Rating[] recommendProducts = bestModel.recommendProducts(userId.intValue(), RECOMMEND_SIZE);

                    List<Integer> products = new ArrayList<>(recommendProducts.length);
                    for (Rating rating : recommendProducts) {
                        products.add(rating.product());
                    }

                    String key = "QUANZI_VIDEO_RECOMMEND_" + userId;
                    jedisCluster.set(key, StringUtils.join(products, ','));
                }
            } finally {
                // Closed exactly once (the original closed it twice and never closed jsc).
                jedisCluster.close();
            }
        } finally {
            jsc.close();
        }
    }

    /**
     * Loads a properties file from the classpath, closing the stream afterwards.
     *
     * @throws IllegalStateException if the resource does not exist
     */
    private static Properties loadProperties(String resource) throws java.io.IOException {
        try (InputStream inputStream = SparkVideo.class.getClassLoader().getResourceAsStream(resource)) {
            if (inputStream == null) {
                throw new IllegalStateException("Missing classpath resource: " + resource);
            }
            Properties properties = new Properties();
            properties.load(inputStream);
            return properties;
        }
    }

    /**
     * Parses a {@code host:port,host:port,...} list into Redis cluster nodes.
     *
     * @throws IllegalStateException if the value is blank or an entry is not {@code host:port}
     */
    private static Set<HostAndPort> parseRedisNodes(String nodesProperty) {
        if (StringUtils.isBlank(nodesProperty)) {
            throw new IllegalStateException("Property redis.cluster.nodes is not configured");
        }
        Set<HostAndPort> nodes = new HashSet<>();
        for (String node : StringUtils.split(nodesProperty, ',')) {
            String[] hostAndPort = StringUtils.split(node.trim(), ':');
            if (hostAndPort.length != 2) {
                throw new IllegalStateException("Invalid redis node (expected host:port): " + node);
            }
            nodes.add(new HostAndPort(hostAndPort[0], Integer.parseInt(hostAndPort[1])));
        }
        return nodes;
    }
}

你可能感兴趣的:(java,学习,scala,spark)