最近做了fasttext的 源码阅读,分享一下心得。
1.所用数据结构:
1)Matrix(父类)->DenseMatrix(子类)
DenseMatrix类里面有vector
初始化变量(m:int64_t, n:int64_t) 表明矩阵的维数。
使用DenseMatrix的模型参数:wi:隐藏层 wo:输出层,fasttext就这两层参数,全部都是用DenseMatrix表示。
模型的输出output变量也是用DenseMatrix定义。
2)Vector
Vector类里面也是vector
初始化变量(m:int64_t) 表明矩阵的维数
使用Vector的模型参数: 模型的状态类State中的变量:包括:hidden,output,grad
3)另外的变量:
除了上面提到的wi,wo,hidden,output, graed; input是由一个vector
2. 封装的类:
1)fasttext
2)Model
3)Loss
分别描述:
1)fasttext: fasttext类提供整个模型训练、预测的入口。其内部变量是模型训练过程中所有参数。
1.模型参数model_ 2. 训练参数 args_ 3. 词典 dict_, 4 模型输入 input_ 5. 模型输出 output. 6. loss_
源码如下:
class FastText {
protected:
std::shared_ptr
args_; std::shared_ptr
dict_; std::shared_ptr
input_; std::shared_ptr
output_; std::shared_ptr
model_; std::atomic
tokenCount_{}; std::atomic
loss_{};
fasttext中共有四种类型的内部函数:
1. 词典生成及序列转换函数:getInputMatrix, getOutputMatrix, getDictionary等
2. 训练函数:cbow,skip,supervise,其中,cbow,skip是训练词向量的, surpervise是训练分类的
3. 预测和实验: test, predict
4. 保存和加载:保存模型,加载模型 saveModel loadModel
2)Model:Model类提供Model训练、预测的方法,隐藏层的计算ComputeHidden, predict,update。
其中,内部变量包括:模型状态变量hidden,output,grad,lossvalue等, wi_(第一层模型参数),wo_(第二层模型参数)
并且内部定义了一个loss对象。
源码举例如下:
void predict(
const std::vector
& input, int32_t k,
real threshold,
Predictions& heap,
State& state) const;
void update(
const std::vector
& input, const std::vector
& targets, int32_t targetIndex,
real lr,
State& state);
void computeHidden(const std::vector
& input, State& state) const;
3)Loss:loss类
Loss类是由model类引用并使用的。其中封装了四种loss的计算方法:
1.OneVsAllLoss
2.NegativeSamplingLoss (默认的loss求法)
3.HierarchicalSoftmaxLoss (hs的求法,fasttext的创新求法:使用了哈夫曼树)
4.SoftmaxLoss (softmax)
另外,还有计算output的ComputOutput函数。
3. 三个模块的关系图可绘如下: