一.概述
SWEM(Baseline Needs More Love: On Simple Word-Embedding-Based Models and Associated Pooling
Mechanisms),基于词向量带有池化的简单方法,是Dinghan Shen等2018年的paper。该方案简单有效,embedding +
pooling+ mlp,堪比FastText与传统的CNN、LSTM的encode进行对比,实验表明词嵌入的重要性,以及对现在有的常见任务最重
要的信息,与存在的缺点等。
使用SWEM该算法的优点是算法超参量不大,几百KB就可以实现不错的效果,与FastText有异曲同工之妙,可以作为
baseline。
github项目地址: https://github.com/yongzhuo/Keras-TextClassification/tree/master/keras_textclassification/m13_SWEM
二. SWEM模型原理等
SWEM方法简单有效,就是一个基于词向量的简单pooling,平均池化(AVG)和最大池化(MAX)无需多说,将平均池化
和最大池化拼接起来也算是一种了。此外,这篇paper中提出的另外一种比较新奇的pooling,也就是层次池化(Hierarchical
Pooling)SWEM-Hierarchical-Pooling方案,首先选择一个为N的滑动窗口对文本进行平均池化(可以理解为Ngram),然后再最
大池化,这么看,也没啥新奇的。
三. SWEM模型结论等
SWEM方法简单有效,谈不上什么创新,不过这篇paper的结论还是比较有意思的。
3.1 word-embedding+pooling对长文本任务有效,而而CNN和LSTM等在短文本任务中效果更佳;
3.2 情感分类任务比主题模型对词序特征更敏感。paper提出的一个简单的分层池层在情感分析任务上
取得了与LSTM/CNN相当的结果;
3.3 自然语言句子配对任务,例如文本蕴涵、文本相似度,简单的词向量池化操作,已经堪比CNN和LSTM了;
3.4 SWEM中的最大池化,对于捕获主题和关键词,效果不错。
四.SWEM代码实现
1. SWEM代码实现与TEXT_CNN、FASTTEXT差不多,简单。
2. github上代码地址:https://github.com/yongzhuo/Keras-TextClassification
3. 主要代码附上:
from keras.layers import GlobalMaxPooling1D, GlobalAveragePooling1D, Concatenate
from keras_textclassification.base.graph import graph
from keras.layers import Dense, Lambda
from keras.models import Model
import tensorflow as tf
class SWEMGraph(graph):
def __init__(self, hyper_parameters):
"""
初始化
:param hyper_parameters: json,超参
"""
self.encode_type = hyper_parameters["model"].get("encode_type", "MAX") # AVG, CONCAT, HIERARCHICAL
self.n_win = hyper_parameters["model"].get("n_win", 3) # n_win=3
super().__init__(hyper_parameters)
def create_model(self, hyper_parameters):
"""
构建神经网络
:param hyper_parameters:json, hyper parameters of network
:return: tensor, moedl
"""
super().create_model(hyper_parameters)
embedding = self.word_embedding.output
def win_mean(x):
res_list = []
for i in range(self.len_max-self.n_win+1):
x_mean = tf.reduce_mean(x[:, i:i + self.n_win, :], axis=1)
x_mean_dims = tf.expand_dims(x_mean, axis=-1)
res_list.append(x_mean_dims)
res_list = tf.concat(res_list, axis=-1)
gg = tf.reduce_max(res_list, axis=-1)
return gg
if self.encode_type=="HIERARCHICAL":
x = Lambda(win_mean, output_shape=(self.embed_size, ))(embedding)
elif self.encode_type=="MAX":
x = GlobalMaxPooling1D()(embedding)
elif self.encode_type=="AVG":
x = GlobalAveragePooling1D()(embedding)
elif self.encode_type == "CONCAT":
x_max = GlobalMaxPooling1D()(embedding)
x_avg = GlobalAveragePooling1D()(embedding)
x = Concatenate()([x_max, x_avg])
else:
raise RuntimeError("encode_type must be 'MAX', 'AVG', 'CONCAT', 'HIERARCHICAL'")
output = Dense(self.label, activation=self.activate_classify)(x)
self.model = Model(inputs=self.word_embedding.input, outputs=output)
self.model.summary(132)
希望对你有所帮助!