记录一次用spark java写文件到本地(java推荐算法)

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.io.Serializable;

/**
 * Collaborative-filtering movie recommender built on Spark ML's ALS.
 *
 * Reads "userid,movieid,rating,timestamp" CSV lines from a local file, trains an
 * ALS model on an 80/20 split, prints the test RMSE, and writes top-10
 * recommendations per user and per movie to local output directories.
 */
public class RecommendMovie {

    /**
     * One rating record, mirroring one line of the input file.
     * Must be {@link Serializable} and a Java bean (no-arg constructor + getters)
     * so Spark can build a DataFrame from it via createDataFrame(rdd, Rating.class).
     */
    public static class Rating implements Serializable {
        private int userid;
        private int movieid;
        private float rating;
        private long timestamp;

        /** No-arg constructor required by Spark's bean-based row encoder. */
        public Rating() {
        }

        public Rating(int userid, int movieid, float rating, long timestamp) {
            this.userid = userid;
            this.movieid = movieid;
            this.rating = rating;
            this.timestamp = timestamp;
        }

        public int getUserid() {
            return userid;
        }

        public int getMovieid() {
            return movieid;
        }

        public float getRating() {
            return rating;
        }

        public long getTimestamp() {
            return timestamp;
        }

        /**
         * Parses one CSV line "userid,movieid,rating,timestamp" into a {@code Rating}.
         *
         * @param str one raw input line
         * @return the parsed rating
         * @throws IllegalArgumentException if the line does not contain exactly 4 fields
         * @throws NumberFormatException    if any field is not a valid number
         */
        public static Rating parseRating(String str) {
            String[] movieInfo = str.split(",");
            if (movieInfo.length != 4) {
                throw new IllegalArgumentException("Each line must contain 4 fields");
            }
            int userid = Integer.parseInt(movieInfo[0]);
            int movieid = Integer.parseInt(movieInfo[1]);
            float rating = Float.parseFloat(movieInfo[2]);
            long timestamp = Long.parseLong(movieInfo[3]);
            return new Rating(userid, movieid, rating, timestamp);
        }
    }

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .master("local[*]")
                .appName("RecommendMovie")
                .getOrCreate();

        // Parse the raw text file into a typed RDD.
        // FIX: use JavaRDD<Rating> instead of the raw JavaRDD type.
        JavaRDD<Rating> ratingsRDD = spark.read()
                .textFile("C:\\Users\\13373\\Desktop\\test.data")
                .javaRDD()
                .map(Rating::parseRating);

        // Convert to a DataFrame so the spark.ml ALS implementation can consume it.
        Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);

        // Random 80/20 train/test split.
        Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];

        // Configure ALS. NOTE: setRegParam is the regularization parameter (lambda),
        // not a "minimum squared error" as the original comment claimed.
        ALS als = new ALS()
                .setMaxIter(5)
                .setRegParam(0.01)
                .setUserCol("userid")
                .setItemCol("movieid")
                .setRatingCol("rating");

        ALSModel model = als.fit(training);

        // Drop NaN predictions for users/items unseen during training,
        // otherwise the RMSE below would be NaN.
        model.setColdStartStrategy("drop");

        Dataset<Row> predictions = model.transform(test);

        // Evaluate with root-mean-square error on the held-out test set.
        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("rating")
                .setPredictionCol("prediction");

        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root-mean-square error = " + rmse);

        // Top-10 movie recommendations for every user.
        // BUG FIX: these were previously written to "itemCF.txt"; note that
        // saveAsTextFile creates a directory at the given path, not a single file.
        Dataset<Row> userCF = model.recommendForAllUsers(10);
        userCF.toJavaRDD().coalesce(1).saveAsTextFile("C:\\Users\\13373\\Desktop\\userCF.txt");

        // Top-10 user recommendations for every movie.
        // BUG FIX: this dataset was computed but never persisted in the original.
        Dataset<Row> itemCF = model.recommendForAllItems(10);
        itemCF.toJavaRDD().coalesce(1).saveAsTextFile("C:\\Users\\13373\\Desktop\\itemCF.txt");

        spark.stop();
    }

}

pom文件:



<?xml version="1.0" encoding="UTF-8"?>
<!-- NOTE(review): the XML tags of this pom were stripped when the post was
     published; reconstructed below from the surviving values. The property
     element names are inferred from their values (junit 4.12, mysql-connector
     5.1.38, slf4j 1.7.21, logback 1.2.11) — confirm against the original pom. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>Aiads</groupId>
    <artifactId>morgan13</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <java.version>1.8</java.version>
        <junit.version>4.12</junit.version>
        <mysql.version>5.1.38</mysql.version>
        <slf4j.version>1.7.21</slf4j.version>
        <logback.version>1.2.11</logback.version>
        <scala.version>2.11.11</scala.version>
        <spark.version>2.2.0</spark.version>
    </properties>

    <dependencies>
        <!-- Spark MLlib: ALS and RegressionEvaluator -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>

        <!-- Spark core -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>

        <!-- Spark SQL: SparkSession / Dataset -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>

        <!-- Scala runtime matching the _2.11 Spark artifacts -->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.11.11</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.7.0</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <profiles>
        <profile>
            <id>aiads</id>
            <properties>
                <maven.compiler.source>1.8</maven.compiler.source>
                <maven.compiler.target>1.8</maven.compiler.target>
                <java.version>1.8</java.version>
            </properties>
            <repositories>
                <repository>
                    <id>nexus</id>
                    <name>local private nexus</name>
                    <url>http://nexus.aiads.com/repository/maven-public</url>
                    <releases>
                        <enabled>true</enabled>
                    </releases>
                    <snapshots>
                        <enabled>true</enabled>
                    </snapshots>
                </repository>
            </repositories>
            <pluginRepositories>
                <pluginRepository>
                    <id>nexus</id>
                    <name>local private nexus</name>
                    <url>http://nexus.aiads.com/repository/maven-public</url>
                    <releases>
                        <enabled>true</enabled>
                    </releases>
                    <snapshots>
                        <enabled>true</enabled>
                    </snapshots>
                </pluginRepository>
            </pluginRepositories>
        </profile>
    </profiles>
</project>


 

 

你可能感兴趣的:(记录一次用spark java写文件到本地(java推荐算法))