大规模文本多元标签分类(XML-CNN)

1.前言

今天分享一篇大规模文本多元标签分类的paper(XML-CNN),论文标题为:Deep Learning for Extreme Multi-label Text Classification,论文发表在2017年的SIGIR上,论文下载链接。

作者解决的是extreme multi-label问题,就是标签种类特别多,涉及上百,上千级别的label。在该问题中,存在数据稀疏问题,即某个label可能就一条数据对应,样本很不平衡;此外,因为标签的量级,往往导致训练和预测计算量很大。基于上述问题,作者提出利用卷积网络(CNN)来解决XML问题。从模型框架上看,尤其站在现在的发展视角来看,本文的model其实很简单。

我分享的目的是:模型在解决实际问题中,应该越简单越好,这也是奥卡姆剃刀原理,我们应该更熟悉简洁模型,探究其存在的一些缺陷,然后改进。

2.模型

模型框架

上图是论文的model框架,是一个标准的CNN模型应用NLP任务的结构。整体分为1.表征层(Embedding layer),2.卷积层(Convolutional layer),3.池化层(Pooling layer),4.全连接层(FC),5分类层(sigmoid output)。

2.1 Embedding layer

Input: 一个文本(docments);
文本有m个word的词组成,词的向量维度为k,词向量可以使用word2vec、fasttext等工具进行预训练过来,也可以按分布初始化随机。
output:文本向量

embedding

2.2 Convolutional layer

应用窗口大小为h(即h-gram)进行卷积操作,卷积的操作主要目的提取label对应的词特征。


convoluation

2.3 Pooling layer

在CNN应用NLP任务,一般主要应用Max-pooling,保留最显著的特征,过滤其他不重要的特征,也降低了模型的复杂度。在本篇论文中,作者用的是 Dynamic Max Pooling,这个概率就是K-max-pooling,取前K个最大值进行保留,Max-pooling只是取最大的一个。应用Dynamic Max Pooling的理由是:在多标签中分类中,仅依靠一个最主要特征是不够的,很可能区分不了两个特别相近的标签,应该把其他显著的特征也加进去学习。

Dynamic Max Pooling的原理其实很简单,将卷积的结果进行排序,然后去topK就可以。


dynamic-max-pooling

2.4 FC layer

作者在pooling layer和output layer加了一个全连接层(FC),文中也称hidden bottleneck layer,意思是在FC层的隐藏单元数量远远小于前后两个layer的隐藏单元数量。应用的原因主要两个:一是pooling layer 单元太多,直接对outputlayer,计算量大;二是太多的隐藏单元,不利于后面的稳健表达和预测;

2.5 Ouput layer

在最后的输出层,也是分类层,文中采用的sigmoid函数,其由采用的损失函数来决定的,损失函数采用的binary entropy loss,也就是逻辑回归中的损失函数。


loss function

以上是模型的整体结构和主要公式,对比标准的CNN来看,作者只是在pooling和分类层进行了修改,model很简洁。接着说说实验部分。

5 Experiment

实验数据集,作者用了6个公开数据集,可以说实验是很充分,label个数由百个到60万。详细统计看下图:


dataset

在评价中指标中,作者使用了准确率和NDCG(信息检索领域的评价指标).


evaluation

在模型对比上,作者用了7个对比模型,包括传统的方法和两个CNN方法,其结果按两个评价指标如下图所示:


P@K
NDCG@K

结果分析:从实验结果看,作者提出的XML-CNN方法的确很有效,在有些数据集上有达30%的提升;在对比常规的CNN方法,也有约4%的提升,说明作者提出的动态池化和瓶颈层是有效的。

此外,作者还对比了模型的运行效率,说明XML-CNN运行效率还是很不错的,结果如下图:


cost time

6 结语

整体来说,本篇应算是深度学习解决multi-label任务的开篇之作,取的效果是很不错的,模型也很简单。站在本篇论文的肩上,其实还有很多可优化的地方,比如考虑label之间的层级关系,或者在embedding上在提出一些改进的trick等。这些问题,读者有兴趣可以去尝试。这篇论文在GitHub有一些复现的代码,推荐一个siddsax/XML-CNN。

你可能感兴趣的:(大规模文本多元标签分类(XML-CNN))