天池学习赛之新闻文本分类——Task4深度学习fastText

本篇博客是天池学习赛之新闻文本分类系列的第四篇,主要是总结基于深度学习的文本分类方法fastText。

文章目录

        • 一、FastText
        • 二、基于FastText的文本分类
          • 2.1 fasttext库的安装
          • 2.2 fasttext库的使用

一、FastText

FastText是一种典型的深度学习词向量的表示方法,它非常简单通过Embedding层将单词映射到稠密空间,然后将句子中所有的单词在Embedding空间中进行平均,进而完成分类操作。所以FastText是一个三层的神经网络,输入层、隐含层和输出层。

天池学习赛之新闻文本分类——Task4深度学习fastText_第1张图片

FastText在文本分类任务上,是优于TF-IDF的:

  • FastText用单词的Embedding叠加获得的文档向量,将相似的句子分为一类
  • FastText学习到的Embedding空间维度比较低,可以快速进行训练

二、基于FastText的文本分类

2.1 fasttext库的安装

Windows直接使用pip安装会报错,去https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext下载对应版本的whl文件,再使用pip安装即可。

天池学习赛之新闻文本分类——Task4深度学习fastText_第2张图片
2.2 fasttext库的使用
import numpy as np
import pandas as pd
import warnings
import fasttext
from sklearn.metrics import f1_score
# 忽略警告
warnings.filterwarnings('ignore')
# 多行显示
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
# 导入数据集
train_df = pd.read_csv('./data/train.csv', sep='\t', nrows=200000)
test_df = pd.read_csv('./data/test.csv', sep='\t', nrows=50000)
# 合并数据集
data_df = pd.concat([train_df, test_df], axis=0, ignore_index=True)
# 格式转换
train_count = train_df.shape[0]
test_count = test_df.shape[0]
val_count = int(train_df.shape[0] * 0.20)
data_df['label_ft'] = '__label__' + data_df['label'].astype(str)
data_df[['text','label_ft']].iloc[:-(val_count+test_count)].to_csv('train.csv', index=None, header=None, sep='\t')
# 模型训练
model = fasttext.train_supervised('train.csv', lr=1.0, wordNgrams=2, verbose=2, minCount=1, epoch=25, loss="hs")
# 验证和预测
val_pred = [model.predict(x)[0][0].split('__')[-1] for x in data_df.iloc[-(val_count+test_count):-test_count]['text']]
test_pred = [model.predict(x)[0][0].split('__')[-1] for x in data_df.iloc[-test_count:]['text']]
print(f1_score(data_df['label'].values[-(val_count+test_count):-test_count].astype(str), val_pred, average='macro'))
# 提交结果
sub = pd.Series(test_pred,name='label')
sub = pd.to_numeric(sub).astype('int').to_frame() # 将object类型的Series转化为int类型DataFrame
sub.to_csv('./results/sub20200727-1.csv', index=False)

最终结果:

  • 验证集F1分数:0.9147696363768255
  • 测试集F1分数:0.9127

你可能感兴趣的:(自然语言处理)