贝叶斯神经网络
1 人工神经网络的作用与局限
人工神经网络(artificial neural network,缩写ANN),简称神经网络(neural network,缩写NN)或类神经网络,是一种模仿生物神经网络(动物的中枢神经系统,特别是大脑)的结构和功能的数学模型或计算模型,主要用于对函数进行估计或近似。
在神经网络的训练中,就是训练网络中的参数以实现预测的结果如下所示
y p r e d i c t = W T ∗ x + b y_{predict} = W^T * x +b ypredict=WT∗x+b
对于一个神经网络来说,最为核心的是如何根据训练集的数据,得到各层的模型参数,使得Loss最小,因其强大的非线性拟合能力而在各个领域有着重要应用。而其问题是在数据量较少的情况下存在严重的过拟合现象,对于获得数据代价昂贵的一些课题比如车辆控制等领域,应用存在局限性。
2 贝叶斯网络简介
贝叶斯神经网络(BNN)不同于一般的神经网络,其权重参数是随机变量,而非确定的值。贝叶斯神经网络把权重看成是服从均值为 μ \mu μ ,方差为 δ \delta δ 的高斯分布,每个权重服从不同的高斯分布,反向传播网络优化的是权重,贝叶斯神经网络优化的是权重的均值和方差。BNN把概率建模和神经网络结合起来,并能够给出预测结果的置信度。
2.1 BNN模型
假设BNN的网络参数为 W W W, p ( W ) p(W) p(W) 是参数的先验分布,给定观测数据 D = X , Y D = X,Y D=X,Y, 这里 X X X是输入数据, Y Y Y 是标签数据。BNN 希望给出以下的分布:
网络的预测值为:
P ( Y ∣ X , D ) = ∫ P ( Y ∣ X , W ) P ( W ∣ D ) d W P(Y|X,D)=\int P(Y|X,W)P(W|D) dW P(Y∣X,D)=∫P(Y∣X,W)P(W∣D)dW (1)
由于
W W W 是随机变量,因此,我们的预测值也是个随机变量。
P ( Y ∣ X , W ) P(Y|X,W) P(Y∣X,W) 表示给定权重
W W W 和输入
X X X,输出
Y Y Y的概率分布,其实就是神经网络。我们只需要依据训练集
D D D 建模出权重的分布
P ( W ∣ D ) P(W|D) P(W∣D),就可以依据蒙特卡罗方法,采样
m m m个服从
P ( W ∣ D ) P(W|D) P(W∣D)分布的样本,计算
1 m ∑ i = 1 m p ( Y ∣ X , W i ) \frac{1}{m}\sum_{i=1}^mp(Y|X,W_i) m1∑i=1mp(Y∣X,Wi) ,即可得到
p ( Y ∣ X , D ) p(Y|X,D) p(Y∣X,D) 。
其中:
P ( W ∣ D ) = P ( W ) P ( D ∣ W ) P ( D ) P(W|D)=\frac{P(W)P(D|W)}{P(D)} P(W∣D)=P(D)P(W)P(D∣W) (2)
P ( W ∣ D ) P(W|D) P(W∣D) 是后验分布,
P ( D ∣ W ) P(D|W) P(D∣W) 是似然函数,
P ( D ) P(D) P(D) 是边缘似然。
2.2 基于变分推断的BNN训练
上述公式 ( 1 ) (1) (1) 说明用NN 对数据进行概率建模并预测的核心在于做高效近似后验推断。如果直接采样后验概率 P ( W ∣ D ) P(W|D) P(W∣D) 来评估 p ( Y ∣ X , D ) p(Y|X,D) p(Y∣X,D) 的话,存在后验分布多维的问题,而变分推断的思想是使用简单分布去近似后验分布。
核心思想是利用一个分布利用一个分布 q ( W ∣ θ ) q(W|\theta) q(W∣θ) 来逼近 p ( W ∣ D ) p(W|D) p(W∣D) ,利用KL散度度量 q ( W ∣ θ ) q(W|\theta) q(W∣θ) 、 p ( W ∣ D ) p(W|D) p(W∣D)两个分布之间的相似性。 其中 θ = ( μ , δ ) \theta = (\mu,\delta) θ=(μ,δ), 表示每个权重 w i w_i wi 从正态分布 ( μ i , δ i ) (\mu_i,\delta_i) (μi,δi) 中采样。 希望 q ( W ∣ θ ) q(W|\theta) q(W∣θ) 和 p ( W ∣ D ) p(W|D) p(W∣D) 距离最小,也就是优化:
θ ∗ = arg min θ K L [ q ( W ∣ θ ) ∣ ∣ p ( W ∣ D ) ] \theta^*=\argmin_{\theta} KL[q(W|\theta)||p(W|D)] θ∗=θargminKL[q(W∣θ)∣∣p(W∣D)] (3)
进一步推导:
θ ∗ = arg min θ K L [ q ( W ∣ θ ) ∣ ∣ p ( W ∣ D ) ] \theta^*=\argmin_{\theta} KL[q(W|\theta)||p(W|D)] θ∗=θargminKL[q(W∣θ)∣∣p(W∣D)]
= arg min θ E q ( W ∣ θ ) [ l o g [ q ( W ∣ θ ) p ( W ∣ D ) ] ] = \argmin_{\theta} E_{q(W|\theta)}[log[\frac{q(W|\theta)}{p(W|D)}]] =θargminEq(W∣θ)[log[p(W∣D)q(W∣θ)]] (KL散度的定义)
= arg min θ E q ( W ∣ θ ) [ l o g [ q ( W ∣ θ ) P ( D ) P ( D ∣ W ) P ( W ) ] ] = \argmin_{\theta} E_{q(W|\theta)}[log[\frac{q(W|\theta)P(D)}{P(D|W)P(W)}]] =θargminEq(W∣θ)[log[P(D∣W)P(W)q(W∣θ)P(D)]](贝叶斯理论)
= arg min θ E q ( W ∣ θ ) [ l o g [ q ( W ∣ θ ) P ( D ∣ W ) P ( W ) ] ] = \argmin_{\theta} E_{q(W|\theta)}[log[\frac{q(W|\theta)}{P(D|W)P(W)}]] =θargminEq(W∣θ)[log[P(D∣W)P(W)q(W∣θ)]](P(D)不依赖于
θ \theta θ,消去) (4)
公式中, q ( W ∣ θ ) q(W|\theta) q(W∣θ) 表示给定正态分布的参数后,权重参数的分布。 P ( D ∣ W ) P(D|W) P(D∣W) 表示给定网络参数后,观测数据的似然; P ( W ) P(W) P(W) 表示权重的先验。
使用
L = − E q ( W ∣ θ ) [ l o g [ q ( W ∣ θ ) P ( D ∣ W ) P ( W ) ] ] L=- E_{q(W|\theta)}[log[\frac{q(W|\theta)}{P(D|W)P(W)}]] L=−Eq(W∣θ)[log[P(D∣W)P(W)q(W∣θ)]] (5)
来表示变分下界ELBO, 也就是公式(4)等价于最大化ELBO:
L = ∑ i l o g q ( w i ∣ θ i ) − ∑ i l o g P ( w i ) − ∑ j l o g P ( y j ∣ w , x j ) L=\sum_{i}log q(w_i|\theta_i) - \sum_ilogP(w_i) - \sum_j logP(y_j| w,x_j) L=∑ilogq(wi∣θi)−∑ilogP(wi)−∑jlogP(yj∣w,xj) (6)
其中, D = ( x , y ) D={(x,y)} D=(x,y).
我们需要对公式(4)中的期望进行求导,但是,这里,我们使用对权重进行重参数的技巧:
w i = μ i + σ i × ϵ i w_i = \mu_i +\sigma_i\times \epsilon_i wi=μi+σi×ϵi (7)
其中,
ϵ i ∼ N ( 0 , 1 ) \epsilon_i\sim N(0,1) ϵi∼N(0,1)。
于是,用
ϵ \epsilon ϵ 代替
w w w 后有:
∂ ∂ θ E q ( ϵ ) [ l o g [ q ( W ∣ θ ) P ( D ∣ W ) P ( W ) ] ] = E q ( ϵ ) [ ∂ ∂ θ l o g [ q ( W ∣ θ ) P ( D ∣ W ) P ( W ) ] ] \frac{\partial}{\partial \theta}E_{q(\epsilon)}[log[\frac{q(W|\theta)}{P(D|W)P(W)}]] = E_{q(\epsilon)}[\frac{\partial}{\partial \theta} log[\frac{q(W|\theta)}{P(D|W)P(W)}]] ∂θ∂Eq(ϵ)[log[P(D∣W)P(W)q(W∣θ)]]=Eq(ϵ)[∂θ∂log[P(D∣W)P(W)q(W∣θ)]] (8)
也就是说,我们可以通过 多个不同的
ϵ i ∼ N ( 0 , 1 ) \epsilon_i\sim N(0,1) ϵi∼N(0,1) ,求取
∂ ∂ θ l o g [ q ( W ∣ θ ) P ( D ∣ W ) P ( W ) ] \frac{\partial}{\partial \theta} log[\frac{q(W|\theta)}{P(D|W)P(W)}] ∂θ∂log[P(D∣W)P(W)q(W∣θ)] 的平均值,来近似 KL 散度对
θ \theta θ 的求导。
此外,除了对
w w w 进行重采样之外,为了保证
θ \theta θ 参数取值范围包含这个实轴,对
δ \delta δ 进行重采样,可以令,
δ = l o g ( 1 + e ρ ) \delta=log(1+e^{\rho}) δ=log(1+eρ) (9)
然后, θ = ( μ , ρ ) \theta = (\mu,\rho) θ=(μ,ρ).
2.3 BNN 算法流程
- 从 N ( μ , l o g ( 1 + e ρ ) ) N(\mu, log(1+e^{\rho})) N(μ,log(1+eρ)) 中采样, 获得 w w w;
- 分别计算 l o g q ( w ∣ θ ) logq(w|\theta) logq(w∣θ)、 l o g p ( w ) log p(w) logp(w)、 l o g p ( y ∣ w , x ) . log p(y|w,x). logp(y∣w,x). 其中,计算 l o g p ( y ∣ w , x ) log p(y|w,x) logp(y∣w,x) 实际计算 l o g p ( y ∣ y p r e ) , y p r e = w ∗ x log p(y|y_{pre}),y_{pre}=w*x logp(y∣ypre),ypre=w∗x。也就可以得到
L = ∑ i l o g q ( w i ∣ θ i ) − ∑ i l o g P ( w i ) − ∑ j l o g P ( y j ∣ w , x j ) L=\sum_{i}log q(w_i|\theta_i) - \sum_ilogP(w_i) - \sum_j logP(y_j| w,x_j) L=∑ilogq(wi∣θi)−∑ilogP(wi)−∑jlogP(yj∣w,xj) (6)
- 更新参数 θ ′ = θ − α ▽ θ L \theta' = \theta - \alpha \bigtriangledown_\theta L θ′=θ−α▽θL.