中文短文本分类实例十三-SWEM(Baseline Needs More Love: On Simple Word-Embedding-Based Models and Associated Po)

一.概述

        SWEM(Baseline Needs More Love: On Simple Word-Embedding-Based Models and Associated Pooling

Mechanisms),基于词向量带有池化的简单方法,是Dinghan Shen等2018年的paper。该方案简单有效,embedding +

pooling+ mlp,堪比FastText与传统的CNN、LSTM的encode进行对比,实验表明词嵌入的重要性,以及对现在有的常见任务最重

要的信息,与存在的缺点等。      

        使用SWEM该算法的优点是算法超参量不大,几百KB就可以实现不错的效果,与FastText有异曲同工之妙,可以作为

baseline。   

        github项目地址:  https://github.com/yongzhuo/Keras-TextClassification/tree/master/keras_textclassification/m13_SWEM

二. SWEM模型原理等

        SWEM方法简单有效,就是一个基于词向量的简单pooling,平均池化(AVG)和最大池化(MAX)无需多说,将平均池化

和最大池化拼接起来也算是一种了。此外,这篇paper中提出的另外一种比较新奇的pooling,也就是层次池化(Hierarchical

Pooling)SWEM-Hierarchical-Pooling方案,首先选择一个为N的滑动窗口对文本进行平均池化(可以理解为Ngram),然后再最

大池化,这么看,也没啥新奇的。

 

三. SWEM模型结论等

        SWEM方法简单有效,谈不上什么创新,不过这篇paper的结论还是比较有意思的。

         3.1  word-embedding+pooling对长文本任务有效,而而CNN和LSTM等在短文本任务中效果更佳;

         3.2  情感分类任务比主题模型对词序特征更敏感。paper提出的一个简单的分层池层在情感分析任务上

                取得了与LSTM/CNN相当的结果;

        3.3   自然语言句子配对任务,例如文本蕴涵、文本相似度,简单的词向量池化操作,已经堪比CNN和LSTM了;

        3.4   SWEM中的最大池化,对于捕获主题和关键词,效果不错。

 

四.SWEM代码实现

        1.   SWEM代码实现与TEXT_CNN、FASTTEXT差不多,简单。

        2.   github上代码地址:https://github.com/yongzhuo/Keras-TextClassification

        3.   主要代码附上:

        

from keras.layers import GlobalMaxPooling1D, GlobalAveragePooling1D, Concatenate
from keras_textclassification.base.graph import graph
from keras.layers import Dense, Lambda
from keras.models import Model
import tensorflow as tf


class SWEMGraph(graph):
    def __init__(self, hyper_parameters):
        """
            初始化
        :param hyper_parameters: json,超参
        """
        self.encode_type = hyper_parameters["model"].get("encode_type", "MAX") # AVG, CONCAT, HIERARCHICAL
        self.n_win = hyper_parameters["model"].get("n_win", 3) # n_win=3
        super().__init__(hyper_parameters)


    def create_model(self, hyper_parameters):
        """
            构建神经网络
        :param hyper_parameters:json,  hyper parameters of network
        :return: tensor, moedl
        """
        super().create_model(hyper_parameters)
        embedding = self.word_embedding.output

        def win_mean(x):
            res_list = []
            for i in range(self.len_max-self.n_win+1):
                x_mean = tf.reduce_mean(x[:, i:i + self.n_win, :], axis=1)
                x_mean_dims = tf.expand_dims(x_mean, axis=-1)
                res_list.append(x_mean_dims)
            res_list = tf.concat(res_list, axis=-1)
            gg = tf.reduce_max(res_list, axis=-1)
            return gg

        if self.encode_type=="HIERARCHICAL":
            x = Lambda(win_mean, output_shape=(self.embed_size, ))(embedding)
        elif self.encode_type=="MAX":
            x = GlobalMaxPooling1D()(embedding)
        elif self.encode_type=="AVG":
            x = GlobalAveragePooling1D()(embedding)
        elif self.encode_type == "CONCAT":
            x_max = GlobalMaxPooling1D()(embedding)
            x_avg = GlobalAveragePooling1D()(embedding)
            x = Concatenate()([x_max, x_avg])
        else:
            raise RuntimeError("encode_type must be 'MAX', 'AVG', 'CONCAT', 'HIERARCHICAL'")

        output = Dense(self.label, activation=self.activate_classify)(x)
        self.model = Model(inputs=self.word_embedding.input, outputs=output)
        self.model.summary(132)

 

 

希望对你有所帮助!

 

 

你可能感兴趣的:(中文短文本分类)