之前写了一篇fasttext文本分类的文章,三个类别的准确率达到90+%,这篇文章主要是想测试一下TextCNN在文本分类任务上的效果,与fasttext对比,孰优孰劣。
代码已上传至GitHub:TextCNN文本分类
torch==1.9.0
gensim==3.8.3
其他的缺啥装啥吧
gensim4.x版本与3.x版本有些参数名变了,报错了百度下都可以解决。
由于数据集太大了,无法上传至GitHub,数据集链接:fasttext分类数据集
百度云:链接
提取码:96fu
PS: 与fasttext使用的是同一份数据。
class TextCNN(nn.Module):
def __init__(self, config):
super(TextCNN, self).__init__()
self.embedding = nn.Embedding(config.vocab_size, config.embedding_size)
if config.use_pretrained_w2v:
self.embedding.weight.data.copy_(config.embedding_pretrained)
self.embedding.weight.requires_grad = True
self.convs = nn.ModuleList([nn.Conv2d(1, config.kenel_num, (k, config.embedding_size)) for k in config.kenel_size])
self.dropout = nn.Dropout(config.dropout)
self.fc = nn.Linear(config.kenel_num * len(config.kenel_size), config.num_classes)
def forward(self, x):
x = self.embedding(x)
x = x.unsqueeze(1)
x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
x = [F.max_pool1d(line, line.size(2)).squeeze(2) for line in x]
x = torch.cat(x, 1)
x = self.dropout(x)
out = self.fc(x)
out = F.log_softmax(out, dim=1)
return out
第二列6个的矩形便是self.convs
这个地方有些人会把它分开来写,效果也是一样的:
self.conv1 = nn.Conv2d(1, config.kenel_num, (config.kernel_size[0], config.embedding_size))
self.conv2 = nn.Conv2d(1, config.kenel_num, (config.kernel_size[1], config.embedding_size))
self.conv3 = nn.Conv2d(1, config.kenel_num, (config.kernel_size[2], config.embedding_size))
相应地,如果你这样分开来写,在forward
的时候也需要写三遍。
由于文本的词向量每个维度代表的含义不一样,在不同的维度上进行运算得到的结果没有实际意义。因此,卷积核的长度必须与词向量的维度一致,并且,卷积核移动的方向是从上至下移动的。
其实很好理解:
因为卷积核的长度等于词向量的维度,上图中,从上至下卷积,就是在词与词之间卷积,得到语义信息。如果卷积核的宽度为2,卷完I like
之后,如果步长为1
,则接着卷like this
,有没有发现,这个非常像bi-gram
的操作。使用不同size
的卷积核,得到的特征所代表的含义也不同(可以借助图像的卷积来理解,不同的卷积核可以卷出不同方向上的直线)。
这里的参数config.use_pretrained_w2v
用来控制是否使用预训练的词向量,对比一下使用预训练的词向量作为embedding层
的权重与随机初始化embedding层
权重的结果
深度学习模型一般都逃不过这两个东西吧。
def build_word2id(lists):
maps = {}
for item in lists:
if item not in maps:
maps[item] = len(maps)
return maps
word2id['PAD'] = len(word2id)
id2word = {word2id[w]: w for w in word2id}
lists
为训练数据分词后组成的列表。
注意:为了方便,我没有严格操作,正常的操作需要加UNK
。因为测试集中有可能存在未登录词,以及训练词向量时也会过滤掉一些低频词,导致这些词不在word2id
中。
def load_w2v():
train_save_path = './data/three_class/train.csv'
dev_save_path = './data/three_class/dev.csv'
test_save_path = './data/three_class/test.csv'
data = concat_all_data(train_save_path, dev_save_path, test_save_path)
model_save_path = './checkpoints/w2v_model.bin'
vec_save_path = './checkpoints/w2v_model.txt'
if not os.path.exists(vec_save_path):
sent = [str(row).split(' ') for row in data['text_seg']]
phrases = Phrases(sent, min_count=5, progress_per=10000)
bigram = Phraser(phrases)
sentence = bigram[sent]
cores = multiprocessing.cpu_count()
w2v_model = Word2Vec(
min_count=2,
window=2,
size=300,
sample=6e-5,
alpha=0.03,
min_alpha=0.0007,
negative=15,
workers=cores-1,
iter=7)
t0 = time()
w2v_model.build_vocab(sentence)
t1 = time()
print('build vocab cost time: {}s'.format(t1-t0))
w2v_model.train(
sentence,
total_examples=w2v_model.corpus_count,
epochs=20,
report_delay=1
)
t2 = time()
print('train w2v model cost time: {}s'.format(t2-t1))
w2v_model.save(model_save_path)
w2v_model.wv.save_word2vec_format(vec_save_path, binary=False)
如果你的gensim版本不同的话,报错自行解决。
def get_pretrainde_w2v():
w2v_path = './checkpoints/w2v_model.txt'
w2v_model = KeyedVectors.load_word2vec_format(w2v_path, binary=False)
word2id_path = './data/three_class/word2id.json'
id2_word_path = './data/three_class/id2word.json'
with open(word2id_path, 'r', encoding='utf-8') as f:
word2id = json.load(f)
with open(id2_word_path, 'r', encoding='utf-8') as f:
id2word = json.load(f)
vocab_size = len(word2id)
embedding_size = 300
weight = torch.zeros(vocab_size, embedding_size)
for i in range(len(w2v_model.index2word )):
try:
index = word2id[w2v_model.index2word [i]]
except:
continue
weight[index, :] = torch.from_numpy(w2v_model.get_vector(
id2word[str(word2id[w2v_model.index2word [i]])]))
# print(weight)
return weight
不出意外的话,加载过程中会报错,大概这样:
raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
这是因为这份数据没有经过严格的清洗,你看:
有些词竟然用下划线连着,不过这种词不会影响我们加载,最主要的原因是分词的时候,文本中存在 \t、\n
等字符,不过还好,这种情况不多,遇到上面的错误的时候,打开生成的w2v_model.txt
,根据报错信息提示的行号,改一下对应的行就好了,把那些空行去掉,再次加载就不会报错了。
train epoch: 10, train_acc: 60.4980084432466%
eval epoch: 10, step: 25, loss: 0.18598
eval epoch: 10, step: 50, loss: 0.18880
eval epoch: 10, step: 75, loss: 0.18921
eval epoch: 10, step: 100, loss: 0.18730
eval epoch: 10, train_acc: 69.08880414056115%
开始测试:
test acc: 60.761789600967354%
开始预测:
电气 试验 本书 共 七章 主要 内容 电气 绝缘 基础理论 知识 液体 固体 组合 绝缘 电 特性 电气设备 交流 耐奈 试验 几个 问题
预测正确,预测的label:工业技术, 正确的类别是: 工业技术
测试集准确率只有60%
,未免太低了吧。
一度怀疑自己模型搞错了。
train epoch: 10, train_acc: 82.52915782192859%
eval epoch: 10, step: 25, loss: 0.10667
eval epoch: 10, step: 50, loss: 0.10773
eval epoch: 10, step: 75, loss: 0.10916
eval epoch: 10, step: 100, loss: 0.10584
eval epoch: 10, train_acc: 83.21983110868973%
开始测试:
test acc: 82.49697702539298%
开始预测:
电气 试验 本书 共 七章 主要 内容 电气 绝缘 基础理论 知识 液体 固体 组合 绝缘 电 特性 电气设备 交流 耐奈 试验 几个 问题
预测正确,预测的label:工业技术, 正确的类别是: 工业技术
准确率提高到了82%
,总体来说还行。
1、TextCNN训练速度比较快,在三个类别的数据上准确率达到82%
2、与fasttext相比,TextCNN的效果要差一些,fastetxt准确率93%