// ALS algorithm recommendation

package zqr.com;



import breeze.optimize.linear.LinearProgram;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import scala.Tuple2;
import org.apache.spark.mllib.recommendation.Rating;

import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.util.*;

public class AlsAddFp {




    public static String k="";
    public static Map,String>map=new HashMap<>();
    public static SetNotbygoods=new HashSet();
    public static Setbygoods=new HashSet();
    public static Map,String> md=new HashMap();
    public static List d=new ArrayList();
    public static void main(String args []) {


        // 输入用户id
        System.out.println("输入用户id");
        Scanner scan=new Scanner(System.in);
        String number=scan.nextLine();


        SparkConf conf = new SparkConf().setAppName("Spark WordCount written by Java").setMaster("local");

        JavaSparkContext sc = new JavaSparkContext(conf); // 其底层实际上就是ScalaSparkContext
//
//        // 加载并解析数据
//        String path = "/usr/local/spark/data/mllib/sample_fpgrowth.txt";
//
//        JavaRDD data = sc.textFile(path);
//        data.collect().forEach(System.out::println);
//
//
//        JavaRDD> data_deal=data.map(x->{
//
//                String arr[]=x.split(" ");
//                String vaule="";
//                if(map.containsKey(arr[0])){
//                    for(int i=1;i//                        vaule = vaule +" "+arr[i];
//                    }
//                   map.put(arr[0],map.get(arr[0]) + " "+vaule.toString()) ;
//                }
//                else{
//                    for(int i=1;i//                        vaule = vaule +" "+arr[i];
//                    }
//                   map.put(arr[0],vaule.toString()) ;
//                }
//            return map;
//
//        });
//        data_deal.collect();
//        for (Map.Entry entry : map.entrySet()) {
//            //Map.entry 映射项(键-值对)  有几个方法:用上面的名字entry
//            //entry.getKey() ;entry.getValue(); entry.setValue();
//            //map.entrySet()  返回此映射中包含的映射关系的 Set视图。
//            System.out.println("----------------------->>>>"+entry.getKey()+","+entry.getValue());
//        }
//
=======================================================================================================================
//
//
//        Map> mp=new HashMap>();
//
//
//            for (Map.Entry entry : map.entrySet()) {
//                Mapcharacter=new HashMap();
//                             //Map.entry 映射项(键-值对)  有几个方法:用上面的名字entry
//                             //entry.getKey() ;entry.getValue(); entry.setValue();
//                             //map.entrySet()  返回此映射中包含的映射关系的 Set视图。
//                k=entry.getKey().toString();
//                             String []arr=entry.getValue().split(" ");
//                             for(String s:arr){
//                                 if(character.containsKey(s)&&s.length()>0){
//                                     int value=character.get(s);
//                                     character.put(s,value+1);
//                                 }else if(s!=" "&&s.length()>0){
//                                     character.put(s,1);
//                                 }
//                             }
//                             System.out.println("key= " + entry.getKey() + " and value= "
//                                             + entry.getValue());
//            mp.put(k,character);
//            }
//
//        for (Map.Entry> entry : mp.entrySet()) {
//            //Map.entry 映射项(键-值对)  有几个方法:用上面的名字entry
//            //entry.getKey() ;entry.getValue(); entry.setValue();
//            //map.entrySet()  返回此映射中包含的映射关系的 Set视图。
//            System.out.println("----------------------->>>>"+entry.getKey()+","+entry.getValue());
//            for(Map.Entry entr : entry.getValue().entrySet()) {
//                String str = entry.getKey()+","+entr.getKey()+","+entr.getValue()+"\n";
//
//                System.out.println(str);
//
//                FileOutputStream fos = null;
//                try {
//                    fos = new FileOutputStream("/home/zqr/桌面/file",true);
//                } catch (FileNotFoundException e) {
//                    e.printStackTrace();
//                }
//               //true表示在文件末尾追加
//                try {
//                    fos.write(str.getBytes());
//                } catch (IOException e) {
//                    e.printStackTrace();
//                }
//            }
//        }
//=======================================================================================================================



//=======================================================================================================================






        sc.setLogLevel("WARN");


        JavaRDD data1 = sc.textFile("/home/zqr/桌面/file");

//        JavaRDD>> ratings = data1.map(s -> {
//            Map> m=new HashMap>();
//            String[] sarray = s.split(",");
//            Map mm=new HashMap();
//            mm.put(sarray[1],Integer.parseInt(sarray[2]));
//            m.put(sarray[0],mm);
//            return m;
//        });





        JavaRDD ratings = data1.map(s -> {
            String[] sarray=s.split(",");
            int i=0;
            if(sarray[0]!=null) {
                byte[] gc = sarray[1].getBytes();
                i = (int) gc[0];
            }
            if(!sarray[0].equals(number)){
                Notbygoods.add(sarray[1]);
               // System.out.println("此处测试语句1"+sarray[1]);
            }else{
                bygoods.add(sarray[1]);
               // System.out.println("此处测试语句2"+sarray[1]);
            }

                return new Rating(Integer.parseInt(sarray[0]),
                        i,
                        Integer.parseInt(sarray[2]));

        });
ratings.collect();
//============================================================================================

        System.out.println(Notbygoods);
        System.out.println(bygoods);

        for(String f:Notbygoods){
            md.put(f,"");
        }
//2System.out.println(md.entrySet());
        for(String h:bygoods){
            if(md.containsKey(h)){
                md.remove(h);
            }
        }
        //System.out.println(md.entrySet());

        for (Map.Entry, String> en : md.entrySet()) {
        //System.out.println(en.getKey() + ":" + en.getValue());
            byte[] gc = en.getKey().getBytes();
            int k = (int) gc[0];
        d.add(k);
        }


        System.out.println("输出的list列表为:"+d);
        //=====================================================================================


/**
 *使用ALS 构建推荐模型
 */
        // 隐性因子个数
        int rank = 10;
        //迭代次数
        int numIterations = 10;
        //lambdaALS的正则化参数;
        MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);

        System.out.println("model:"+model);

        // 评估评级数据模型
        JavaRDD da=sc.parallelize(d);
        JavaRDD, Object>> userProducts =
                da.map(r -> new Tuple2<>(Integer.parseInt(number), r));


        JavaPairRDD, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD(
                model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD()
                        .map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()))
        );

        //===================================================================================


        System.out.println("打印predictions的值");

        //predictions.collect().forEach(System.out::println);
        //=====================================================================================
        Map,Double> to=new TreeMap,Double>();
        List list=predictions.collect();
        for(Object x : list){
            String string=x.toString();
            String []arr=string.split("[,)(]");
            String uid=arr[2];
            String mid=arr[3];
            double pfen=Double.parseDouble(arr[5]);

            if(uid.equals(number)){
                to.put(mid,pfen);
            }

        }



        //转换成list进行排序
        List, Double>> lil = new ArrayList,Double>>(to.entrySet());
        // 排序
        Collections.sort(lil, new Comparator, Double>>() {
            //根据value排序
            public int compare(Map.Entry, Double> o1,
                               Map.Entry, Double> o2) {
                double result = o1.getValue() - o2.getValue();
                if(result < 0)
                    return 1;
                else if(result == 0)
                    return 0;
                else
                    return -1;

            }
        });
        List l=new ArrayList();
        int num=0;
        for (Map.Entry, Double> entry : lil) {
            System.out.println(entry.getKey() + "  " + entry.getValue());
            l.add(entry.getKey().toString());
            num++;
            if(num==3){
                break;
            }
        }

        for(String x:l){
            int s=Integer.parseInt(x);
            System.out.println((char)s);
        }











    }
    }

// You may also be interested in: (spark, machine learning algorithms)