fasttext源码位置
1:为了学习词向量(向量表示),我们可以使用fasttext.train_unsupervised函数,像下面这样:
import fasttext
# data.txt :准备语料时,只需要去掉原始数据中的label标签即可。
# Skipgram model :
俩种方式
model = fasttext.skipgram('data.txt','model')
model = fasttext.train_unsupervised('data.txt', model='skipgram')
# or, cbow model :
model = fasttext.cbow('data.txt','model')
model = fasttext.train_unsupervised('data.txt', model='cbow')
2:保存和加载模型对象:
model.save_model("model.bin")
model = fasttext.load_model('model.bin')
3:模型运用
# 加载前面训练好的模型 model.bin
model = fasttext.load_model("model.bin")
print (model.words) # model 中的词汇列表
print (model["king"]) # "king" 的词向量
分类过程:
fasttext在进行文本分类时,huffmax树叶子节点处是每一个类别标签的词向量。在训练过程中,训练语料的每一个词也会得到响应的词向量。输入为一个window 内的词对应的词向量,隐藏层为这几个词的线性相加。相加的结果作为该文档的向量。再通过softmax层得到预测标签。结合文档真实标签计算 loss,梯度与迭代更新词向量(优化词向量的表达)。
from fastText import train_supervised, load_model
流程:
1:数据准备 fasttex_train.txt
处理后的数据: 每行代表一个文本,以\n结尾,文本以空格分隔单词,如下所示,文本今天天气真的太好了处理后为:
今天 天气 真的 太好 了 __label__1
2.训练模型
import fasttext
#第一个参数是前面得到的 fasttex_train.txt
model = train_supervised(input=fasttex_train.txt , epoch=10, lr=0.1, wordNgrams=2, minCount=1, loss="softmax")
3.测试模型和使用模型分类
import fasttext
# 测试模型 其中 fasttext_test.txt 就是测试数据,格式和 fasttext_train.txt 一样
result = model.test("fasttext_test.txt")
print( "准确率:",result.precision)
print( "回归率:",result.recall)
4:模型预测
#参数解释
#其中,k=3用来指定获取前3个概率最高的结果,默认k=1。
#如果想要预测多个句子,可以传入一个字符串数组,如下:
model.predict(["Which baking dish is best to bake a banana bread ?", "Why not put knives in the dishwasher?"], k=3)
#测试文本类别,需要将测试的文本进行中文分词,然后使用空格连接起来。
segs = jieba.lcut(test_text)
segs = filter(lambda x:x not in stop_words, segs)
test_text = " ".join(segs)
# 预测,返回类别类型以及概率值。
lables, proba = model.predict(test_text)
print(lables, proba)
5:模型保存,俩种方式
当您想要保存监督模型文件时,fastText可以通过牺牲一点点性能来压缩它以获得更小的模型文件,model_filename.ftz的大小比model_filename.bin小得多。
# 保存模型
model.save_model("fasttext.model.bin")
# 压缩模型
model.quantize(input=train_data, qnorm=True, retrain=True, cutoff=100000)
print_results(*model.test(valid_data))
model.save_model("fasttext.model.ftz") # 保存压缩后的模型
6:模型加载
# 加载模型
model= fasttext.load_model("fasttext.model.bin",label_prefix = "__label__")
参数解释:
1:分类模型参数
def train_supervised(input, lr=0.1, dim=100,
ws=5, epoch=5, minCount=1,
minCountLabel=0, minn=0,
maxn=0, neg=5, wordNgrams=1,
loss="softmax", bucket=2000000,
thread=12, lrUpdateRate=100,
t=1e-4, label="__label__",
verbose=2, pretrainedVectors=""):
"""
训练一个监督模型, 返回一个模型对象
@param input: 训练数据文件路径
@param lr: 学习率
@param dim: 向量维度
@param ws: cbow模型时使用
@param epoch: 次数
@param minCount: 词频阈值, 小于该值在初始化时会过滤掉
@param minCountLabel: 类别阈值,类别小于该值初始化时会过滤掉
@param minn: 构造subword时最小char个数
@param maxn: 构造subword时最大char个数
@param neg: 负采样
@param wordNgrams: n-gram个数
@param loss: 损失函数类型, softmax, ns: 负采样, hs: 分层softmax
@param bucket: 词扩充大小, [A, B]: A语料中包含的词向量, B不在语料中的词向量
@param thread: 线程个数, 每个线程处理输入数据的一段, 0号线程负责loss输出
@param lrUpdateRate: 学习率更新
@param t: 负采样阈值
@param label: 类别前缀
@param verbose: ??
@param pretrainedVectors: 预训练的词向量文件路径, 如果word出现在文件夹中初始化不再随机
@return model object
"""
2:词向量训练
def train_unsupervised(input, model="skipgram", lr=0.05, dim=100,
ws=5, epoch=5, minCount=5,
minCountLabel=0, minn=3,
maxn=6, neg=5, wordNgrams=1,
loss="ns", bucket=2000000,
thread=12, lrUpdateRate=100,
t=1e-4, label="__label__",
verbose=2, pretrainedVectors=""):
"""
训练词向量,返回模型对象
输入数据不要包含任何标签和使用标签前缀
@param model: 模型类型, cbow/skipgram两种
其他参数参考train_supervised()方法
@return model
"""
pass
3:模型压缩
def model.quantize(self, input=None, qout=False, cutoff=0,
retrain=False, epoch=None, lr=None, thread=None,
verbose=None, dsub=2, qnorm=False):
"""
减小模型大小和内存占用
@param input: 训练数据文件路径
@param qout: 是否修剪输出层
@param cutoff: 重新训练的词和n-gram的个数
@param retrain: 是否重新训练
@param epoch: 次数
@param lr: 学习率
@param thread: 线程个数
@param verbose: ?
@param dsub: 压缩方法将向量分成几块, 每块大小
@param qnorm: 是否归一化(l2范数)
"""
pass
4:模型预测
def model.predict(self, text, k=1, threshold=0.0):
"""
模型预测,给定文本预测分类
@param text: 字符串, 需要utf-8
@param k: 返回标签的个数
@param threshold 概率阈值, 大于该值才返回
@return 标签列表, 概率列表
"""
[1]: http://meta.math.stackexchange.com/questions/5020/mathjax-basic-tutorial-and-quick-reference
[2]: https://mermaidjs.github.io/
[3]: https://mermaidjs.github.io/
[4]: http://adrai.github.io/flowchart.js/