【Gumbel-softmax 采样】再参数化

本方法来源于论文:A Novel Attribute Reconstruction Attack in Federated Learning 。

Gumbel-max 和 Gumbel-softmax 都是一种在参数化的采样方法,在离散变量的采样中具有:将某一随机离散变量X变得对每一维度概率可导的 作用。

为什么要 Gumbel-max(softmax)?

Gumbel-Max解决了这么一个问题:
我们知道一个离散随机变量X的分布,比如说p(X=1)=p1=0.2,p(X=2)=p2=0.3,p(X=3)=p2=0.5,然后我们想得到一些服从这个分布的离散的x的值。我们一般的思路当然是,就按照这个概率去采样嘛,采样一些x来用就行了。

但是这么做有一个问题:我们采样出来的x只有值,没有生成x的式子。本来x的值和p1,p2,p3是相关的,但是我们使用采样这么一个办法之后,我们得到的x没有办法对p1,p2,p3求导这在神经网络里面就是一个大问题,没法BP了嘛。很多时候我们只是要x的期望,那么我们就是x=p1+2p2+3p3,x对p1,p2,p3的导数都很清楚,逆向传播很好实现。但是我们这里的需求是采样,要得到一些实际的x值,就像上面说的,不能求导的问题就来了。

那么,能不能给一个以p1,p2,p3为参数的公式,让这个公式返回的结果是x的采样呢?这样的话,我们就可以对这个公式求导,从而得到采样的x对p1,p2,p3的导数了!答案当然是:能!

我们所想要的就是下面这个式子,即gumbel-max技巧:

在这里插入图片描述

其中 在这里插入图片描述
这一项名叫Gumbel噪声,这个噪声是用来使得z的返回结果不固定的(每次都固定一个值就不叫采样了嘛)。最终我们得到的z向量是一个one_hot向量,用这个向量乘一下x的值域向量,得到的就是我们要采样的x

可以看到,上面从p生成x的整个过程里面,不可导的函数只有argmax,于是我们用可导的softmax代替一下这里的argmax函数,问题完全解决。最终得到的z向量为:
在这里插入图片描述
这个式子里的参数 T 越小,z越接近one_hot向量。然后我们得到了一些可以对p求导的x的取样值,当然因为我们最后用的是softmax,所以x的值跟纯粹的取样也不完全一样,但比起直接求期望,我们至少得到了样本,不是吗?

这个过程相当于我们把不可导的取样过程,从x本身转嫁到了求取x的公式中的一项g上面,而g不依赖于p1,p2,p3。这样一来,x对p1,p2,p3仍然是可导的,而我们得到的x仍然是离散值的采样。目标达成。这样的采样过程转嫁的技巧有一个专门的名字,叫再参化技巧(reparameterization trick),有兴趣的同学可以去搜一下。

你可能感兴趣的:(方法,机器学习,概率论,深度学习)