ICML 2017 Best Paper理解

一、背景

2017年ICML的最佳论文奖被来自于斯坦福大学的Pang Wei Koh和Percy Liang拿下,论文名是《Understanding Black-box Predictions via Influence Functions》,研究内容是关于神经网络的可解释性。论文地址:[1]。

正如论文题目所表明的,本文的核心是influence function(影响函数)。这是一个来自于稳健统计学(robust statistics)的概念[2],功能是告诉我们当upweight训练样本一个无穷小的量,模型的参数是怎么变化的。

在了解本文是怎样使用influence function之前,我们先来认识下本文的背景。深度学习,尤其是神经网络,近些年来在很多领域都取得了很好的结果,如图像识别、自然语言处理和语音识别等领域。但是,一直以来困扰深度学习的一个问题是深度学习的可解释性太差了。深度学习的研究人员通过改进神经网络结构、调节模型参数和增加正则化方法等,让模型的结果越来越好,但他们并不知道模型内部做了什么让结果变得更好,因此,深度学习也被人戏称为“炼金业”,深度学习的研究人员也被称为“炼金术士”,所谓的炼金也就是指调参。而本文就是通过influence function很好地了解深度学习的“黑盒”,让深度学习拥有了一些解释性。

二、方法

本文的特色之一是其严格、完整的形式化证明,因此先来交待一下本文所研究问题的数学定义。本文研究的是从输入空间X(如图像数据)到输出空间Y(如标签数据)的一般性的预测问题。记训练样本$z_{1}, ..., z_{n}$,其中$z_{i} = (x_{i}, y_{i}) \in X \times Y$。那么经验风险最小化为:

首先,本文假设风险是二阶可微,并严格的凸函数。因为Influence function是从训练数据的视角去观察模型的学习过程,所以先尝试更改训练数据。本文共做了两方面的数据更改,一是upweight训练数据,二是disturb训练数据。

(1)upweight训练数据。

初衷是为了观察删除一个训练样本对模型的影响。删除一个训练样本$z$得到新的经验风险最小化为:(有一个疑问:这里为什么没有乘以1/(n-1)?)

但这样一个接一个地删除训练样本所花费的重训练代价太高。幸运地是,influence function提供了一个高效的近似:删除一个训练样本等价于增加训练样本$z$一个很小的值$\epsilon$。那么得到的新的参数为:

基于文献[3]得到新参数的影响为:

其中$H_{\hat{\theta}}$是Hessian矩阵且是正定的。(因为假设经验风险是二阶可微且严格凸的)


最后,使用链式法则得到在训练样本$z_{test}$影响的函数的closed-form表达式为:

ICML 2017 Best Paper理解_第1张图片

(2)disturb训练数据。

是为了观察修改训练样本的影响。假设一训练样本为$z = (x,y)$,得到新的修改样本为$z_{\delta} = (x + \delta, y)$。得到新的经验风险最小化为:


同样基于文献[3]得到新参数的影响为:

ICML 2017 Best Paper理解_第2张图片

如果$x$是连续的且$\delta$非常小,那么新参数的影响可近似为:


最后使用链式法则,得到disturb的影响为:

ICML 2017 Best Paper理解_第3张图片

三、优化

在计算(upweight和disturb训练数据都会涉及这个计算)


时,会遇到两个问题:一是需要求Hessian矩阵的逆和Hessian矩阵,对于n个训练样本的训练数据而言,其时间复杂度为$O(np^{3} + p^{3})$;二是需要计算所有样本的$I_{up,loss}(z_{i}, z_{test})$。这两个问题都会带来非常大的计算开销。

本文使用基于Hessian-vector products(HVPs)的两种方法近似计算解决上述两个问题:


第一种方法是conjugate gradients(CG)[4],第二种方法是stochastic estimation[5],详细介绍见论文。从而使计算所有样本点的时间复杂度被降低到$O(np+rtp)$。

四、验证与拓展

(1)损失函数与欧式距离的区别。


主要有两点区别:第一点是给更高训练损失的样本点更大的影响,认为异常值会对模型参数有更大的影响;第二点是权重协方差矩阵(即Hessian矩阵的逆)表明其它训练样本会对当前训练的样本产生影响。

(2)Influence function和留一法重训练的区别。

ICML 2017 Best Paper理解_第4张图片

直线是留一法的结果,蓝色点是influence function得到的结果。可以看到,两方法得到的结果是接近的,也表明了influence function是一个好的近似,并且Hessian求逆近似也取得了好的结果。

(3)非凸和不收敛。

即使得到的不是全局最小值,而是局部极小值,influence function也能获得有意义的结果。

(4)不可微的损失。

本文采用hinge作为损失函数(不可微),然后使用平滑的hinge损失。结果发现平滑后的结果和原结果十分接近。

五、实例

(1)理解模型行为。

influence function揭示了模型如何依赖于训练数据和从训练数据推断的。两个模型可以通过不同的方式做出相同的预测。本文使用Inception v3和使用RBF作为核函数的SVM两个模型做了实验。

(2)对抗训练样本。

关键观点是告诉了我们怎样修改训练样本使损失变化的最大。

(3)调试领域不匹配。

领域不匹配是指训练集的分布和测试集的分布不匹配,会造成模型在训练时有一个高准确度,但在测试集上准确度很低。因为influence function可以找到对这类错误影响最大的训练样本。

(4)纠正错误标签的样本。

标记出对模型训练影响最大的一些训练样本。

参考文献:

[1]Koh P W, Liang P. Understanding black-box predictions via influence functions[J]. arXiv preprint arXiv:1703.04730, 2017.

[2]Cook R D, Weisberg S. Characterizations of an empirical influence function for detecting influential cases in regression[J]. Technometrics, 1980, 22(4): 495-508.

[3] Cook R D, Weisberg S. Residuals and influence in regression[M]. New York: Chapman and Hall, 1982.

[4] Martens J. Deep learning via Hessian-free optimization[C]//ICML. 2010, 27: 735-742.

[5] Agarwal N, Bullins B, Hazan E. Second-order stochastic optimization in linear time[J]. stat, 2016, 1050: 15.

你可能感兴趣的:(ICML 2017 Best Paper理解)