spark 1.1 mllib中 NaiveBayes 源码阅读

代码:mllib/api/classification/NaiveBayes.scala

模型主要是三个变量, labels存储类别,pi存储各个label的prior, theta matrix存储各个词在各个类别中的条件概率。

训练部分:代码的run部分
首先是检测feature部分的值,必须是非负的。如果是伯努利分布的话,features是0,1的向量;多项式分布,features则是term frequency
map 输出 (label, features)
combineByKey 输出每个label的一些统计信息, key: label,  value:(#doc, each feature's tf in this label).输出aggregated
combineByKey 分为三个部分,createCombiner 检验每个value: features必须是非负的,输出value:(1, features), 第一个value中1是表示一个文档; mergeValue 是一个combiner合并新的value,首先是累加1,也就是累加文档的个数,其次是累加features,计算各个features的出现次数; mergeCombiner 合并两个combiner,分别累加两个value。
接下来,遍历aggregated,统计labels, pi, theta.  在这里要注意引入了润滑参数lambda。

预测部分:
将模型进行广播,对于每个data, 用map操作输出data的预测label值,map执行predict函数,
利用贝叶斯公式求解p(c|x) = p(x|c)p(c), 两边取log,logp(c|x) = logp(x|c) + log(p(c)), 输出值最大的label

对combineByKey方法不熟悉的,参见http://abshinn.github.io/python/apache-spark/2014/10/11/using-combinebykey-in-apache-spark/ 教程。

你可能感兴趣的:(源码,spark,NaiveBayes)