main.cc
main():1.train(), 2.test(),3. quantize(), 4.printWordVectors(), 5.printSentenceVectors(), 6.printNgrams(), 7.nn(), 8.analogies(), 9.predict(), 10.dump(), 11.printUsage()
1.train() :
fasttext.train()训练
fasttext.saveModel()保存模型
fasttext.saveVectors()保存词向量,第一行是word数+空格+维度dim;从第二行开始 word+vector
fasttext.train():readFromFile(), loadVectors()/nowrds()+bucket, nlabels()/nwords(), startThreads(),setTargetCounts()
readFromFile()读取文件,获取词典:
readWord():从文件中读入word,以空格、tab等作为分词依据;
add():若word已经存在于词典中,count+1,否则加入词典中;
threshold():根据词频排序,剔除阈值以下的词,构造哈希;
initTableDiscard():初始化丢弃规则;
initNgrams() :初始化Ngram
若使用预训练的词向量文件,则loadVectors()加载该文件;否则,以(字典words数+bucket数)*dim词向量维度作为输入
输出为label数*dim(分类)或者是words数*dim(词向量)。
当 args_->model == model_name::sup 时,训练分类器, 所以输出的种类是标签总数 dict_->nlabels();
否则训练的是词向量,输出种类就是词的种类 dict_->nwords()。
startThreads():
trainThread():根据线程数,将训练文件按照总字节数均分成多个部分;根据训练需求的不同,用的更新策略也不同,它们分别是:
1. supervised()2. cbow() 3. skipgram()
model.update():
computeHidden():计算前向传播:输入层 -> 隐层
loss=ns时,negativeSampling():负采样,训练时每次选择一个正样本,随机采样几个负样本,每种输出都对应一个参数向量,保存于 wo_ 的各行。对所有样本的参数更新,都是一次独立的 LR 参数更新。
loss=hs时,hierarchicalSoftmax():层次softmax() ,对于每个目标词,都可以在构建好的霍夫曼树上确定一条从根节点到叶节点的路径,路径上的每个非叶节点都是一个 LR,参数保存在 wo_ 的各行上,训练时,这条路径上的 LR 各自独立进行参数更新。
else,普通softmax()
model.getLoss():
2.test():fasttext.loadModel(),fasttext.test()
fasttext.loadModel():以binary格式加载模型
fasttext.test():
Dictionary::getLine():
addSubwords():
Model::predict() 预测k个最有可能的分类
computeHidden():input上下文,计算隐层向量
如果是hierarchical softmax,dfs()遍历树寻找概率最大值;
否则,findKBest()数组输出前k个最大值
computeOutputSoftmax():计算softmax值,存入output中
输出预测结果
3. quantize()
fasttext.loadModel():以binary格式加载模型
fasttext.quantize();
fasttext.saveModel();
4.printWordVectors()
fasttext.loadModel():以binary格式加载模型
fasttext.getDimension():获取词向量维度
fasttext.getWordVector():输入word,返回vector
5.printSentenceVectors()
fasttext.loadModel():以binary格式加载模型
fasttext.getDimension():获取词向量维度
fasttext.getSentenceVector():
如果是有监督学习:addInputVector()所有词向量累加做average
6.printNgrams()
fasttext.loadModel()::以binary格式加载模型
fasttext.ngramVectors():
7.nn():
fasttext.loadModel():以binary格式加载模型
fasttext.getDictionary()
fasttext.precomputeWordVectors():对词典中的所有word加载vector
输入queryWord
fasttext.getWordVector(queryVec, queryWord);返回输入queryWord的queryVec
fasttext.findNN(wordVectors, queryVec, k, banSet, results);使用优先队列存储最相近的word
8.analogies():
fasttext.loadModel();
fasttext.analogies(k);
9.predict():
Model::predict():给输入数据打上 1 ~ K 个类标签,并输出各个类标签对应的概率。同上
10.dump():
11.printUsage()