Learned-Mixin +H

什么是数据偏差(bias)?

作者举了两个自然语言处理领域的例子,例如在MNLI上训练的隐含模型将仅根据特定关键词的存在或句子对是否包含相同的单词来猜测答案,而在SQuAD上训练的QA模型倾向于选择问题词附近的文本作为答案,而不管上下文。在视觉问答领域,数据偏差对应的就是语言先验,即模型仅根据问题或大部分权重来自问题来决给出预测答案。

解决问题

我们都知道,一个学得很多偏差的模型在训练集上效果很好,但是推广到域外数据集就表现得很差。本文就是为了解决语言先验问题,尽可能的降低训练集产生的偏差对模型的影响,从而使我们的模型可以更好地推广到域外或对抗环境。

但是模型学得偏差是不可避免的,而且很多流行的数据集本身也都存在偏差。在本文中,作者想通过识别数据集偏差,然后不让模型使用该偏差,从而提高模型的域外性能。大题思路就是:对偏差用简单的、受约束的基线方法来显式建模,然后通过基于集成的训练将它们从最终模型中剔除。

网络框架

Learned-Mixin +H_第1张图片

 橙色虚框代表在训练集上训练,绿色虚框代表测试。黑色的箭头是数据传输方向,红色箭头是梯度传播方向(反向传播,优化参数)。本文的模型由两部分构成,一部分是仅偏差模型(Bias-Only Model),一部分是健壮模型(Robust Model)。

仅偏差模型:只以问题作为输入,输出偏差预测答案。所以说模型是一个只有偏差的模型,其学到可以产生语言先验的特征。

健壮模型:输入数据集的问题-图像对,输出预测答案。主模型是采用的up-down模型作为基础模型,我们是采用集合方法(Ensemble)用前面预训练的仅偏差模型来训练这个模型。

总的来说,就是先让仅偏差模型训练产生偏差结果,在根据其产生的预测采用集合方法去训练健壮模型,计算loss,优化参数。为什么这样设置呢?因为我们之前说了,本文想要先识别偏差,再去除偏差。所以第一个仅偏差模型就已经捕获了目标模式(偏差特征),健壮模型也就不会再去学习了,因此在模式不可靠的测试数据上做得更好。

主要公式

我们假设一个预训练的仅偏差预测器h,其中h(xi) = bi,bi代表仅偏差模型的对第i个问题的预测概率,第二个预测函数f,其中f(xiθ)=pipi是类的一个相似的概率分布。我们的目标就是构造一个训练目标来优化参数θ,这样f就会学会选择正确的类,而不去使用仅偏差模型捕获的特征。是我们最后健壮模型输出的第i个实例的预测概率。训练的时候,损失loss使用和真实值通过二叉熵损失函数计算,梯度则通过函数f反向传播。

Bias Product(偏差积):=softmax(log(pi)+log(bi))

Learned-Mixin(混合学习):=softmax (log(pi) + g(xi) log(bi))

该公式较上面的加了个权重g(学习函数),表示已知模型的输入时,我们选择相信仅偏差模型预测值的程度。g计算为softplus(w*hi),其中w是一个学习向量,hi是模型的最后一个隐藏层,例如xi。使用softplus(x) = log(1+ex)函数来防止模型乘以一个负的权重而导致的逆转偏差(更小了)。w用模型的其余参数进行训练。当g(xi) = 1时,这就变成了偏差积。

Learned-Mixin +H(LMH): R = wH(softmax (g(xi) log(bi)))

 R表示我们在loss函数上加的一个熵惩罚项,H(z) =−∑j zj log(zj)为熵,w为超参数。因为上面LM的模型可以将偏差集成到pi中(学到了偏差),然后设置g(xi) = 0。那我们构建仅模型的作用就没了,所以加了一个熵惩罚项来缓解这个问题。惩罚熵会导致偏差分量不均匀,从而对集合产生更大的影响。也就是说,本来LM模型只有用到了偏差,那当g(xi) = 0时,集合模型就丢失了偏差的影响。现在我又加上一个熵惩罚,熵惩罚也用到了偏差,所以即使g(xi) = 0,熵惩罚也不为0(因为softmax(零tensor:logists)会得到一个值全为1/logists.size(0)的新tensor),甚至还会产生更大的误差。

实验结果

Learned-Mixin +H_第2张图片

 上图是在VQA-CP v2.0测试集的结果。由图可知,混合学习模型使VQA-CP的性能提高了约9个点,而熵正则化器又将性能提高3个点。对于混合学习(learned-mixin整体,我们发现g(xi)与偏差的预期准确性密切相关,在测试数据上的斯皮尔曼相关性(spearmanr correlation)为0.77。

Learned-Mixin +H_第3张图片

问题类型和偏见模型对该类型排名最高的答案如上图所示。G是LM模型的g(xi)值,G+是LMH模型的g(xi)值。可见当仅偏差模型预测正确的时候,两个模型的g(xi)都很大,都选择相信偏差。而当仅偏差模型预测错误的时候,g(xi)相较上面都更小,而且LM模型甚至还产生了上面我们说的g(xi)=0的情况,而加了熵惩罚的LMH则不会。

你可能感兴趣的:(自然语言处理,人工智能)