fasttext的源码阅读

最近做了fasttext的  源码阅读,分享一下心得。

1.所用数据结构:

  1)Matrix(父类)->DenseMatrix(子类)

    DenseMatrix类里面有vector data_的变量,用一个vector保存二维矩阵的信息

    初始化变量(m:int64_t, n:int64_t) 表明矩阵的维数。

    使用DenseMatrix的模型参数:wi:隐藏层 wo:输出层,fasttext就这两层参数,全部都是用DenseMatrix表示。

    模型的输出output变量也是用DenseMatrix定义。

  2)Vector

      Vector类里面也是vector data_的变量,与DenseMatrix不同的是,它是用vector保存一维矩阵的信息

      初始化变量(m:int64_t) 表明矩阵的维数

      使用Vector的模型参数: 模型的状态类State中的变量:包括:hidden,output,grad

  3)另外的变量:

      除了上面提到的wi,wo,hidden,output, graed; input是由一个vector构造的

2. 封装的类:

  1)fasttext

  2)Model

  3)Loss

  分别描述:

  1)fasttext: fasttext类提供整个模型训练、预测的入口。其内部变量是模型训练过程中所有参数。

    1.模型参数model_ 2. 训练参数 args_ 3. 词典 dict_, 4 模型输入 input_ 5. 模型输出 output. 6. loss_

   源码如下:

    class FastText {

protected:

  std::shared_ptr args_;

  std::shared_ptr dict_;

  std::shared_ptr input_;

  std::shared_ptr output_;

  std::shared_ptr model_;

  std::atomic tokenCount_{};

  std::atomic loss_{};


    fasttext中共有四种类型的内部函数:

    1. 词典生成及序列转换函数:getInputMatrix, getOutputMatrix, getDictionary等

    2. 训练函数:cbow,skip,supervise,其中,cbow,skip是训练词向量的, surpervise是训练分类的

    3. 预测和实验: test, predict

    4. 保存和加载:保存模型,加载模型 saveModel loadModel

  2)Model:Model类提供Model训练、预测的方法,隐藏层的计算ComputeHidden, predict,update。

      其中,内部变量包括:模型状态变量hidden,output,grad,lossvalue等, wi_(第一层模型参数),wo_(第二层模型参数)

      并且内部定义了一个loss对象。

     源码举例如下:

void predict(

      const std::vector& input,

      int32_t k,

      real threshold,

      Predictions& heap,

      State& state) const;

  void update(

      const std::vector& input,

      const std::vector& targets,

      int32_t targetIndex,

      real lr,

      State& state);

  void computeHidden(const std::vector& input, State& state) const;

 3)Loss:loss类

      Loss类是由model类引用并使用的。其中封装了四种loss的计算方法:

      1.OneVsAllLoss

      2.NegativeSamplingLoss (默认的loss求法)

      3.HierarchicalSoftmaxLoss (hs的求法,fasttext的创新求法:使用了哈夫曼树)

      4.SoftmaxLoss (softmax)

      另外,还有计算output的ComputOutput函数。

3. 三个模块的关系图可绘如下:


你可能感兴趣的:(fasttext的源码阅读)