基于spark实现TFIDF

上一段实习的时候用spark手写了一个tfidf,下面贴上代码并和spark中的源码进行比较。
输入文本(demo):

文档1:a b c d e f g
文档2:a b c d e f 
文档3:a b c d e
文档4:a b c d
文档5:a b c 
文档6:a b 
文档7:a

输出结果:

代码分析
主要有以下几个步骤:

  1. 读取文件到JavaRDD
  2. mapToPair将每行文本映射为doc <标题 : 单词[]>中,后者为分词后的单词数组
  3. mapValues获取每个文档的词频
  4. 将文档数进行广播,用于计算idf
  5. 类似于wordCount, 先将doc中的每个文本对应的去重单词出现次数置为1,然后aggregateByKey统计每个单词出现的文档数,用对应的求idf的公式,就可以求出idf了
  6. 将表示每个词idf的RDD collect到driver,再进行广播,进行每个文档的tfIdf计算
  7. 最后写入输出文件

和spark Mllib中tf-idf实现方法的对比
源码中也是将tf计算和idf计算分隔开的,tf计算时也是用了HashMap但是使用了hash函数(hashcode取余numfeatures)将词映射到了一个int作为Key.在计算idf时每个文档使用了一个词语大小的向量来保存每个词是否出现过,累加这些向量就得到了整个数据集中每个词语出现的文档数,即IDF,再利用公式计算,不过源码中使用的是log即以e为底而不是以10为底

源码中也是用广播的形式将TF和IDF联系起来

public class GenerateTags {

    public static void main(String[] args) throws IOException{
        SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
//        SparkConf conf = new SparkConf().setAppName("video-tags");
        JavaSparkContext sc = new JavaSparkContext(conf);
        System.setProperty("hadoop.home.dir", "D:\\winutils");
        JavaRDD lines = sc.textFile("C:\\Users\\YANGXIN\\Desktop\\test.txt");

        //得到每个文档标题和对应的词串
        JavaPairRDD docs = lines.mapToPair(new PairFunction() {
            @Override
            public Tuple2 call(String s) throws Exception {
                String[] doc = s.split(":");
                String title = doc[0];
                String[] words = doc[1].split(" ");
                return new Tuple2(title, words);
            }
        });

        //得到每个文档的词频
        JavaPairRDD> docTF = docs.mapValues(new Function>() {
            @Override
            public Map call(String[] strings) throws Exception {
                Map map = new HashMap();
                int sum = strings.length;
                for(String str : strings){
                    double cnt = map.containsKey(str) ? map.get(str) : 1;
                    map.put(str, cnt);
                }
                for(String str : map.keySet()){
                    map.replace(str, map.get(str) / sum);
                }
                return map;
            }
        });

        //文档数
        final Broadcast docCnt = sc.broadcast(docs.count());

        //得到每个词的idf值
        JavaPairRDD ones = docs.flatMapToPair(new PairFlatMapFunction, String, Integer>() {
            @Override
            public Iterable> call(Tuple2 stringTuple2) throws Exception {
                List> list = new ArrayList>();
                Set set = new HashSet();
                for(String str : stringTuple2._2()){
                    set.add(str);
                }
                for(String str : set){
                    list.add(new Tuple2<>(str, 1));
                }
                return list;
            }
        });

        //每个单词在多少个文档中出现了
        JavaPairRDD wordDocCnt= ones.aggregateByKey(0, new Function2() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception { //同partition下的处理
                return integer + integer2;
            }
        }, new Function2() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception { //不同partition下的处理
                return integer + integer2;
            }
        });

        JavaPairRDD wordIdf = wordDocCnt.mapValues(new Function() {
            @Override
            public Double call(Integer integer) throws Exception {
                return Math.log10((docCnt.getValue() + 1) * 1.0 / (integer + 1));  //计算逆文档频率
            }
        });

        //广播idf值,进行tf-idf计算
        Map idfs = wordIdf.collectAsMap();
        final Broadcast> idfMap = sc.broadcast(idfs);

        //计算每个文档的tf-idf向量
        JavaPairRDD> TfIdf = docTF.mapValues(new Function, TreeMap>() {
            @Override
            public TreeMap call(Map stringDoubleMap) throws Exception {
                TreeMap map = new TreeMap();
                for(Map.Entry entry : stringDoubleMap.entrySet()){
                    String word = entry.getKey();
                    Double tf = entry.getValue();
                    Double idf = idfMap.getValue().get(word);
                    map.put(tf * idf, word);
                }
                return map;
            }
        });

        TfIdf.saveAsTextFile("C:\\Users\\YANGXIN\\Desktop\\result");
    }

参考文献:
https://github.com/endymecy/spark-ml-source-analysis/blob/master/%E7%89%B9%E5%BE%81%E6%8A%BD%E5%8F%96%E5%92%8C%E8%BD%AC%E6%8D%A2/TF-IDF.md

你可能感兴趣的:(基于spark实现TFIDF)