论文解读: Few-Shot Text Classification with Induction Network

目的

在文本分类中,经常碰到一些很少出现过的类别或这样不均衡的类别样本,而且当前的few-shot技术经常会将输入的query和support的样本集合进行sample-wise级别的对比。但是,如果跟同一个类别下的不同表达的样本去对比的时候产生的效果就不太好。
因此,文章的作者就提出了,通过学习sample所属于的类别的表示得到class-wise的向量,然后跟输入的query进行对比,这样能比state-of-the-art的模型提高3%正确率,同时泛化的效率也更高。

模型

模型分为三个模块:Encoder, Induction 和 Relation. 大概的架构如下图.


1.png

Data:
构建数据集的时候会把样本分为support set—S 和 query set — Q,support set就是用来训练参数的,query set就是用来模拟真实请求,计算loss的;
support set是从C个Class中,每个class抽出K个样本生成的,那么在C个class中剩余的部分就作为query set.
Encoder Module:
Encoder阶段就是将support set的文本进行encoding; 首先,会经过Bi-LSTM得到这样句子的表示;
假如:support set的样本是m (m=C * K),LSTM输出的表示的维度是u的话,经过Bi-LSTM会得到H,其维度为(m, T, 2u).

2.png

利用Self-Attention得到最终的表示,也希望通过attention的方式来决定哪些hidden state, ht更值得学习。于是,作者就通过将Bi-LSTM得到的表示H,经过线性组合和tanh变换,再做Softmax处理得到attention score — a, 其维度是(m,T);
然后将a(m,T) 乘以原来的每个H(m, T, 2u)的ht,并且相加,得到了e矩阵,其维度变成了(m, 2u).


3.png

Induction Module:
在得到每个样本的表示后,es矩阵(m, 2u),我们下一步需要将其向上抽象成class的表示了;

4.png

首先,通过matrix transformation, Ws(2u,2u),将样本的表示进行变形,从实验结果看,这样能让不同类别的样本区分得更好。同时,由于matrix对于所有样本向量都是共用的,不管什么样的样本size都可以支持了。所以,将Ws(2u,2u)乘以es矩阵(m, 2u)得到es'(m,2u)
5.png

其次,为了确保class的表示已经囊括了这个sample feature vector,我们还会动态地去调整这个coefficients — d, 这个d是在0,1之间分布,用来确保这个sample的类别所属。因此,这里会对耦合系数b进行softmax(在大于一定值后,随着input的增加,softmax的score的值增加得越大); 注意,这个耦合系数b的初始值为0,然后会通过学习来更新。(后面会提到)
6.png

然后,再通过加权聚合来得到class的表示ci',其维度是(k, 2u)
7.png

之后,通过squashing函数将ci'的表示进行压缩,这种压缩不会改变正负但可以减少区间,得到ci其维度是(k, 2u)
8.png

最后,回到刚才提到的b的更新,其实就是动态规划,如果这个样本是属于这个类别的话,这个sample的向量就应该得到更大的值,而且在不同的类别的话,这个值就应该更小;
9.png

总的来说,通过多次迭代后,不但可以让不同class之间的表示得到区分,同时,同一个class下的样本贡献程度也会通过学习后变得不一样。同时,这里的Ws(2u,2u)也会给予后面预测去使用。

Relation Module:
在得到了ci(k, 2u)后,我们就可以计算ci与query set的相关性分数了,作者采用的是neural tensor layer的方式。
首先,从其中一个class开始,假设是ci(k, 2u),先做一次matrix transformation, 将Ci转置得到CiT(2u,k),然后乘以M[1:h],其维度(k,n), 得到中间结果的维度为(2u, n),然后乘以query set, eq(n, 2u)得到结果的维度为(2u, 2u),然后再过一个RELU函数.

10.png

然后,将v(ci,eq)的结果经过全联接,再经过一个sigmoid函数,得到一个第i个class与query的相似度
11.png

目标函数

最后,把riq的值和yq做对,如果匹配就是1,否则就是0,计算query set的loss;


12.png

你可能感兴趣的:(论文解读: Few-Shot Text Classification with Induction Network)