fastText is a word-vector and text-classification tool open-sourced by Facebook. Academically it offers little novelty; its appeal is a simple model and very fast training. A quick trial shows it is easy to work with, and the results are good enough to meet the bar for production use.
In short, what fastText does is map every word in a document to a vector via a lookup table, average those vectors, and feed the average directly into a linear classifier. It closely resembles the deep averaging network (DAN) from ACL-15, as a simplified version with the hidden layers removed. The paper argues that for simple classification tasks, a complex network architecture is unnecessary to achieve comparable results.
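To make this concrete, here is a minimal toy sketch (an illustration, not the library's actual implementation) of the "look up, average, linear classify" pipeline:
import numpy as np

vocab_size, dim, n_classes = 10000, 100, 5
E = np.random.randn(vocab_size, dim) * 0.01   # word lookup table
W = np.random.randn(n_classes, dim) * 0.01    # linear classifier weights

def predict_proba(word_ids):
    doc_vec = E[word_ids].mean(axis=0)        # average the word vectors
    logits = W @ doc_vec                      # linear classifier
    exp = np.exp(logits - logits.max())       # softmax
    return exp / exp.sum()

print(predict_proba([3, 42, 7]))              # class probabilities for a toy document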
The fastText paper also mentions a few training tricks, such as hierarchical softmax for large label sets and the hashing trick for bag-of-n-gram features.
The package, including the fastText Python interface, can be installed with pip install fasttext.
Note: if the direct pip install fails, you can download the matching .whl file from https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext and then install it from the download directory with, for example, pip3 install fasttext-0.9.2-cp36-cp36m-win_amd64.whl.
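After installation, a quick import check confirms the package is usable:
import fasttext  # if this import succeeds, the installation worked
print(fasttext.__file__)  # shows where the package is installed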
For text classification, fastText requires the input text to be stored in the following format:
__label__2 , birchas chaim , yeshiva birchas chaim is a orthodox jewish mesivta high school in lakewood township new jersey . it was
founded by rabbi shmuel zalmen stein in 2001 after his father rabbi
chaim stein asked him to open a branch of telshe yeshiva in lakewood .
as of the 2009-10 school year the school had an enrollment of 76
students and 6 . 6 classroom teachers ( on a fte basis ) for a
student–teacher ratio of 11 . 5 1 .
__label__6 , motor torpedo boat pt-41 , motor torpedo boat pt-41 was a pt-20-class motor torpedo boat of the united states navy built by the
electric launch company of bayonne new jersey . the boat was laid down
as motor boat submarine chaser ptc-21 but was reclassified as pt-41
prior to its launch on 8 july 1941 and was completed on 23 july 1941 .
__label__11 , passiflora picturata , passiflora picturata is a species of passion flower in the passifloraceae family .
__label__13 , naya din nai raat , naya din nai raat is a 1974 bollywood drama film directed by a . bhimsingh . the film is famous as
sanjeev kumar reprised the nine-role epic performance by sivaji
ganesan in navarathri ( 1964 ) which was also previously reprised by
akkineni nageswara rao in navarathri ( telugu 1966 ) . this film had
enhanced his status and reputation as an actor in hindi cinema .
Here __label__ is the label prefix (you can also define your own), and what follows __label__ is the class id.
We define our 5 categories as follows:
import jieba
import pandas as pd
import random
# Map each category to an id, e.g. 'technology' -> 1, 'car' -> 2, ...
cate_dic = {'technology':1, 'car':2, 'entertainment':3, 'military':4, 'sports':5}
# Read the data
df_technology = pd.read_csv("/jhub/students/data/course11/项目2/origin_data/technology_news.csv", encoding='utf-8')
df_technology = df_technology.dropna()
df_car = pd.read_csv("/jhub/students/data/course11/项目2/origin_data/car_news.csv", encoding='utf-8')
df_car = df_car.dropna()
df_entertainment = pd.read_csv("/jhub/students/data/course11/项目2/origin_data/entertainment_news.csv", encoding='utf-8')
df_entertainment = df_entertainment.dropna()
df_military = pd.read_csv("/jhub/students/data/course11/项目2/origin_data/military_news.csv", encoding='utf-8')
df_military = df_military.dropna()
df_sports = pd.read_csv("/jhub/students/data/course11/项目2/origin_data/sports_news.csv", encoding='utf-8')
df_sports = df_sports.dropna()
# Convert the content columns to Python lists
technology = df_technology.content.values.tolist()[1000:21000]
car = df_car.content.values.tolist()[1000:21000]
entertainment = df_entertainment.content.values.tolist()[:20000]
military = df_military.content.values.tolist()[:20000]
sports = df_sports.content.values.tolist()[:20000]
(You can also use the data preprocessed earlier directly.)
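As a quick sanity check, you can print the size of each list (the exact counts depend on your copy of the data):
for name, data in [('technology', technology), ('car', car), ('entertainment', entertainment),
                   ('military', military), ('sports', sports)]:
    print(name, len(data))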
stopwords = pd.read_csv("/jhub/students/data/course11/项目2/origin_data/stopwords.txt", index_col=False,
                        quoting=3, sep="\t", names=['stopword'], encoding='utf-8')
stopwords = stopwords['stopword'].values
def preprocess_text(content_lines, sentences, category):
    for line in content_lines:
        try:
            segs = jieba.lcut(line)
            # drop punctuation, single-character tokens and stopwords
            segs = list(filter(lambda x: len(x) > 1, segs))
            segs = list(filter(lambda x: x not in stopwords, segs))
            # format each sentence as: __label__1 word word word ...
            sentences.append("__label__" + str(category) + " , " + " ".join(segs))
        except Exception:
            print(line)
            continue
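A quick way to see the output format is to run the function on a single sentence (the exact tokens depend on jieba's segmentation and your stopword list):
demo = []
preprocess_text(['这是中国制造的宝马汽车'], demo, cate_dic['car'])
print(demo[0])  # e.g. __label__2 , 中国 制造 宝马 汽车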
# Generate the training data
sentences = []
preprocess_text(technology, sentences, cate_dic['technology'])
preprocess_text(car, sentences, cate_dic['car'])
preprocess_text(entertainment, sentences, cate_dic['entertainment'])
preprocess_text(military, sentences, cate_dic['military'])
preprocess_text(sports, sentences, cate_dic['sports'])
# Shuffle the data
random.shuffle(sentences)
# Write the data to train_data.txt
print("writing data to fasttext format...")
with open('../../tmp/train_data.txt', 'w', encoding='utf-8') as out:
    for sentence in sentences:
        out.write(sentence + "\n")
print("done!")
import fasttext
"""
训练一个监督模型, 返回一个模型对象
@param input: 训练数据文件路径
@param lr: 学习率
@param dim: 向量维度
@param ws: cbow模型时使用
@param epoch: 次数
@param minCount: 词频阈值, 小于该值在初始化时会过滤掉
@param minCountLabel: 类别阈值,类别小于该值初始化时会过滤掉
@param minn: 构造subword时最小char个数
@param maxn: 构造subword时最大char个数
@param neg: 负采样
@param wordNgrams: n-gram个数
@param loss: 损失函数类型, softmax, ns: 负采样, hs: 分层softmax
@param bucket: 词扩充大小, [A, B]: A语料中包含的词向量, B不在语料中的词向量
@param thread: 线程个数, 每个线程处理输入数据的一段, 0号线程负责loss输出
@param lrUpdateRate: 学习率更新
@param t: 负采样阈值
@param label: 类别前缀
@param verbose: ??
@param pretrainedVectors: 预训练的词向量文件路径, 如果word出现在文件夹中初始化不再随机
@return model object
"""
classifier = fasttext.train_supervised(input='../../tmp/train_data.txt', dim=100, epoch=5,
                                       lr=0.1, wordNgrams=2, loss='softmax')
classifier.save_model('../../tmp/classifier.model')
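The saved model can be loaded back later with fasttext.load_model:
classifier = fasttext.load_model('../../tmp/classifier.model')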
# Note: we evaluate on the training file itself here, so these numbers are optimistic
result = classifier.test('../../tmp/train_data.txt')
print('P@1:', result[1])
print('R@1:', result[2])
print('Number of examples:', result[0])
P@1: 0.9807066613391175
R@1: 0.9807066613391175
Number of examples: 87595
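For per-class metrics, recent versions of the Python bindings also provide test_label, which reports precision, recall and F1 for each label:
per_label = classifier.test_label('../../tmp/train_data.txt')
for label, metrics in per_label.items():
    print(label, metrics)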
label_to_cate = {'__label__1': 'technology', '__label__2': 'car', '__label__3': 'entertainment',
                 '__label__4': 'military', '__label__5': 'sports'}
texts = '这 是 中国 制造 宝马 汽车'  # a pre-segmented test sentence: "this is a China-made BMW car"
labels = classifier.predict(texts)
# print(labels)
print(label_to_cate[labels[0][0]])
car
labels, probas = classifier.predict(texts, k=3)
for label, proba in zip(labels, probas):
    print('Prediction: %s\tProbability: %f' % (label_to_cate[label], proba))
Prediction: car	Probability: 0.999747
Prediction: military	Probability: 0.000273
Prediction: technology	Probability: 0.000010
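predict also accepts a list of strings for batch prediction (the second sentence here is just an illustrative, pre-segmented example):
batch_labels, batch_probas = classifier.predict(['这 是 中国 制造 宝马 汽车', '新 赛季 联赛 开幕'], k=1)
print(batch_labels, batch_probas)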
def preprocess_text_unsupervised(content_lines, sentences):
    for line in content_lines:
        try:
            segs = jieba.lcut(line)
            segs = list(filter(lambda x: len(x) > 1, segs))
            segs = list(filter(lambda x: x not in stopwords, segs))
            # format each line as: word word word ...
            sentences.append(" ".join(segs))
        except Exception:
            print(line)
            continue
# Generate the unsupervised training data
sentences = []
preprocess_text_unsupervised(technology, sentences)
preprocess_text_unsupervised(car, sentences)
preprocess_text_unsupervised(entertainment, sentences)
preprocess_text_unsupervised(military, sentences)
preprocess_text_unsupervised(sports, sentences)
print("writing data to fasttext unsupervised learning format...")
with open('../../tmp/unsupervised_train_data.txt', 'w', encoding='utf-8') as out:
    for sentence in sentences:
        out.write(sentence + "\n")
print("done!")
import fasttext
# Skipgram model
model = fasttext.train_unsupervised('../../tmp/unsupervised_train_data.txt', model='skipgram')
print(model.words[:10]) # list of words in dictionary
# CBOW model
model = fasttext.train_unsupervised('../../tmp/unsupervised_train_data.txt', model='cbow')
print(model.words[:10]) # list of words in dictionary
# Look up the vector of a given word
print(model['赛季'])
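The model object can also query the nearest neighbors of a word directly:
print(model.get_nearest_neighbors('赛季'))  # list of (similarity, word) pairs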
def preprocess_text_unsupervised(content_lines, sentences):
    for line in content_lines:
        try:
            segs = jieba.lcut(line)
            segs = list(filter(lambda x: len(x) > 1, segs))
            segs = list(filter(lambda x: x not in stopwords, segs))
            # gensim expects each sentence as a list of words: [word, word, word]
            sentences.append(segs)
        except Exception:
            print(line)
            continue
# Generate the training data for gensim
sentences = []
preprocess_text_unsupervised(technology, sentences)
preprocess_text_unsupervised(car, sentences)
preprocess_text_unsupervised(entertainment, sentences)
preprocess_text_unsupervised(military, sentences)
preprocess_text_unsupervised(sports, sentences)
from gensim.models import word2vec
# gensim < 4.0; in gensim >= 4.0 the `size` argument was renamed to `vector_size`
model = word2vec.Word2Vec(sentences, size=100, window=5, min_count=5, workers=4)
model.save("../../tmp/gensim_word2vec.model")
# Find the most similar words
print(model.wv.most_similar('赛季'))
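gensim's KeyedVectors support other queries as well, e.g. pairwise similarity (the word '比赛' here is just an illustrative choice, assuming it survives the min_count filter):
print(model.wv.similarity('赛季', '比赛'))  # cosine similarity between two words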