Mahout协同过滤算法源码分析(2)--splitDataset 和parallelALS

Mahout版本:0.7,hadoop版本:1.0.4,jdk:1.7.0_25 64bit。

接上篇,此篇blog分析第(1)、(2)对应的java源码,主要是splitDataset和parallelALS。

(一)其中splitDataset对应的mahout中的源java文件是:org.apache.mahout.cf.taste.hadoop.als.DatasetSplitter.java 文件,打开这个文件,可以看到这个类是继承了AbstractJob的,所以需要覆写其run方法。run方法中含有所有的操作。

进入run方法,看到刚开始和之前的算法分析一样,都是参数的获取。然后,新建了3个job,分别是:

Job markPreferences = prepareJob(getInputPath(), markedPrefs, TextInputFormat.class, MarkPreferencesMapper.class,
        Text.class, Text.class, SequenceFileOutputFormat.class);
Job createTrainingSet = prepareJob(markedPrefs, trainingSetPath, SequenceFileInputFormat.class,
        WritePrefsMapper.class, NullWritable.class, Text.class, TextOutputFormat.class);
Job createProbeSet = prepareJob(markedPrefs, probeSetPath, SequenceFileInputFormat.class,
        WritePrefsMapper.class, NullWritable.class, Text.class, TextOutputFormat.class);
首先来看第一个Job任务,这个Job没有reducer,只有一个mapper:MarkPreferencesMapper,打开这个Mapper就可以看到这个任务的具体操作了。

这个mapper含有两个函数,其一:setup,其二:map。看setup中,首先生成了一个随机变量,然后获取traning的数据集大小范围,然后获得prob(在前篇翻译为测试数据集,但是好像这个单词的翻译不是这样的,所以这里保留这个单词好了)数据集的大小范围。在map中则是根据setup中的随机数来把每条记录进行分类:

double randomValue = random.nextDouble();
      if (randomValue <= trainingBound) {
        ctx.write(INTO_TRAINING_SET, text);
      } else if (randomValue <= probeBound) {
        ctx.write(INTO_PROBE_SET, text);
      }
当随机数小于或者等于training的范围阈值traingingBound时就把这条记录标记为T,当随机数大于traingingBound且小于或者等于probeBound(prob数据集大小范围阈值)时,把该条记录标记为P。这里的probBound不一定要是1,意思就是说不一定要使用所有提供的数据集来把它们分为T和P,还可以分为不使用的数据集。

然后来看第二、三个任务,比较这两个任务,可以看到它们的不同之处只是在输入路径和输出路径,以及一些参数不同而已。而且也只是使用mapper,并没有使用reducer,那么打开WritePrefsMapper来看,这个mapper同样含有setup和map函数,setup函数则主要是获取是对T还是对P来进行处理,看map函数:

if (partToUse.equals(key.toString())) {
        ctx.write(NullWritable.get(), text);
      }
map函数就是对第一个job的输出进行处理的,partToUse是T的话,那么就把这条记录输出到TraingingDataSet中(这个是第二个任务)。第三个任务同样的道理,只是提供的partToUse不同,是P而已。这样就把原始数据分出了两部分,一部分是training dataset,一部分是 prob dataset,还有一部分是不使用的dataset(这部分可有可无)。

(二)parallelALS对应的源文件是:org.apache.mahout.cf.taste.hadoop.als.ParallelALSFactorizationJob.java文件。打开这个文件,进入run方法:参数获取完毕后,本次主要分析前面三个Job,分别是itemRatings Job、userRatings Job 和averageRatings Job。

(1)首先来分析itemRatings Job,调用的语句分别是:

Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(),
        TextInputFormat.class, ItemRatingVectorsMapper.class, IntWritable.class,
        VectorWritable.class, VectorSumReducer.class, IntWritable.class,
        VectorWritable.class, SequenceFileOutputFormat.class);
这里有mapper和reducer,先分析mapper,即ItemRatingVectorsMapper,打开这个类看到,这个mapper中就一个map函数:

String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
      int userID = Integer.parseInt(tokens[0]);
      int itemID = Integer.parseInt(tokens[1]);
      float rating = Float.parseFloat(tokens[2]);

      Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
      ratings.set(userID, rating);

      ctx.write(new IntWritable(itemID), new VectorWritable(ratings, true));
这里的操作就是把一条记录转换,然后输出<key,value>对应为 itemID, [userID:rating]这样的输出,然后到reducer,即VectorSumReducer,这个reducer中也只有一个reduce函数:

Vector vector = null;
    for (VectorWritable v : values) {
      if (vector == null) {
        vector = v.get();
      } else {
        vector.assign(v.get(), Functions.PLUS);
      }
    }
    ctx.write(key, new VectorWritable(vector));
这个VectorSumReducer在前面的算法中好像也有分析过,vector.assign(v.get(),Functions.PLUS)是把vector中对应的项相加;比如如果原来的vector为[1:2.3,2:3.3,5:3.4],然后使用上面的assign和Function.PLUS参数加上v,[2:3.3,4:4.0],那么新的vector就更新为[1:2.3,2:6.6,4:4.0,5:3.4],这就是所谓的对应相加。所以这个reducer的输出为itemID ::[userID:rating,userID:rating,...]这样的输出。参考ratings.dat文件的说明文件中说item有3952个记录,由这个job的输出结果来看只有3692条记录输出,说明training dataset中只含有3692个item。可以编写下面的测试文件来读取这个job的输出文件,看是否和设想一样:

package mahout.fansy.als.test;

import java.io.IOException;
import java.util.Map;

import org.apache.hadoop.io.Writable;

import mahout.fansy.utils.read.ReadArbiKV;

public class ReadItemRatings {

	/**
	 *  读取itemRatings Job的输出
	 * @param args
	 * @throws IOException 
	 */
	public static void main(String[] args) throws IOException {

		String path="hdfs://ubuntu:9000/user/mahout/temp/als/itemRatings/part-r-00000";
		Map<Writable, Writable>map=ReadArbiKV.readFromFile(path);
		System.out.println("read "+map.getClass().toString()+" done...");
	}

}
其中的ReadArbiKV类文件在前面的系列算法中有提到代码 Hadoop Writable深度复制及读取任意序列文件,这里就不多说了。
(2)第(1)个job的输出类似:<key,vlaue> --> <itemID,[userID:rating,userID,rating,...]> ,然后到了userRatings job,这个job的调用如下:

Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(),
        TransposeMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class, IntWritable.class,
        VectorWritable.class);
输入是itemRatings job的输出,mapper是TransposeMapper,看这个mapper,其中的map函数源码如下:

protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException {
    int row = r.get();
    Iterator<Vector.Element> it = v.get().iterateNonZero();
    while (it.hasNext()) {
      Vector.Element e = it.next();
      RandomAccessSparseVector tmp = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
      tmp.setQuick(row, e.get());
      r.set(e.index());
      ctx.write(r, new VectorWritable(tmp));
    }
  }
那么row就是itemID了,然后遍历value的值,输出是<userID,[itemID:rating]>,这个就是map的输出了,所以itemRatings job里面的一条记录就对应于这里的map的多条输出了。

看reducer,即MergeVectorsReducer,它的reduce函数更加简单:

public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx)
      throws IOException, InterruptedException {
    ctx.write(key, VectorWritable.merge(vectors.iterator()));
  }
直接调用VectorWritable的merge方法,reducer把相同key的value集中起来,比如user1 :{[item1:rating1],[item2:rating2],...}然后merge方法的操作是:

public static VectorWritable merge(Iterator<VectorWritable> vectors) {
    Vector accumulator = vectors.next().get();
    while (vectors.hasNext()) {
      VectorWritable v = vectors.next();
      if (v != null) {
        Iterator<Vector.Element> nonZeroElements = v.get().iterateNonZero();
        while (nonZeroElements.hasNext()) {
          Vector.Element nonZeroElement = nonZeroElements.next();
          accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
        }
      }
    }
    return new VectorWritable(accumulator);
  }
这里看到是把所有的item整合起来了,所以reducer的输出是 <key,value> --> <userID,[itemID:rating,itemID:rating,...]>,感觉这个和itemRatings job的输出差不多。

把前面读取itemRatings job输出的代码中的路径更改下就可以读取userRatings job的输出了,看是否和猜测的一样。这里通过terminal看到这个job的输出为6040条记录,和ratings.dat描述文件一样,dataset含有6040个用户。

(3)averageRatings job,这个任务的调用代码:

Job averageItemRatings = prepareJob(pathToItemRatings(), getTempPath("averageRatings"),
        AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class,
        IntWritable.class, VectorWritable.class);
这个job的输入文件同样是itemRatings job的输出,即输入文件的格式是:<key,vlaue> --> <itemID,[userID:rating,userID,rating,...]> 。看mapper,即AverageRatingMapper:

 protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException {
      RunningAverage avg = new FullRunningAverage();
      Iterator<Vector.Element> elements = v.get().iterateNonZero();
      while (elements.hasNext()) {
        avg.addDatum(elements.next().get());
      }
      Vector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
      vector.setQuick(r.get(), avg.getAverage());
      ctx.write(new IntWritable(0), new VectorWritable(vector));
    }
  }
首先,看write中可以都是0,那么可以肯定在reducer中的输入和输出都是一条记录而已,这个可以在对job的分析界面(50030)或者terminal中可以得到求证。这里看到的代码的意思是把某个itemid的全部user的评价ratings全部遍历一遍,然后求这些ratings的平均值,然后输出就是<key,value>  -->   <0,[itemID:averageRating]>这样的输出,这里有新的类RunningAverage 和FullRunningAverage,其实这两个类可以暂时不用管的,或者,算了打开看看吧:

public synchronized void addDatum(double datum) {
    if (++count == 1) {
      average = datum;
    } else {
      average = average * (count - 1) / count + datum / count;
    }
  }
可以看到addDatum方法就是算平均值的。然后就是reducer了,reducer还是那个MergeVectorReducer,那么reducer的输出就应该是<key,value> --> <0,[itemID:averageRating,itemID:averageRating,...]>。

接下里就是initializeM和for循环了,今天又晚了。总感觉假期效率低的没法说。。。


分享,成长,快乐

转载请注明blog地址:http://blog.csdn.net/fansy1990


你可能感兴趣的:(Mahout,源码分析,协同过滤,splitDataset,parallelALS)