本文对大部分人来说可能仅仅起到科普的作用,因为Gumbel-Max仅在部分领域会用到,例如GAN、VAE等。笔者是在研究EMNLP上的一篇论文时,看到其中有用Gumbel-Softmax公式解决对一个概率分布进行采样无法求导的问题,故想到对Gumbel-Softmax做一个总结,由此写下本文
假设现在我们有一个离散随机变量 Z Z Z的分布
p 1 = p ( Z = 1 ) = π 1 p 2 = p ( Z = 2 ) = π 2 p 3 = p ( Z = 3 ) = π 3 . . . p x = p ( Z = x ) = π x p_1 = p(Z=1)=\pi_1\\ p_2 = p(Z=2) = \pi_2\\ p_3 = p(Z=3) = \pi_3\\ ...\\ p_x = p(Z=x) = \pi_x\\ p1=p(Z=1)=π1p2=p(Z=2)=π2p3=p(Z=3)=π3...px=p(Z=x)=πx
其中, ∑ i π i = 1 \sum_i \pi_i=1 ∑iπi=1。我们想根据 p 1 , p 2 , . . . , p x p_1,p_2,...,p_x p1,p2,...,px的概率采样得到一系列离散 z z z的值。但是这么做有一个问题,我们采样出来的 z z z只有值,没有生成 z z z的式子。例如我们要求 Z Z Z的期望,那么就有公式
E ( Z ) = p 1 + 2 p 2 + ⋯ + x p x \mathbb{E}(Z) = p_1 + 2p_2 + \cdots +xp_x E(Z)=p1+2p2+⋯+xpx
Z Z Z对 p 1 , p 2 , . . . , p x p_1,p_2,...,p_x p1,p2,...,px的导数都很清楚。但是现在我们的需求是采样一些具体的 z z z值,采样这个操作没有任何公式,因此也就无法求导。于是一个很自然的想法就产生了,我们能不能给一个以 p 1 , p 2 , . . . , p z p_1,p_2,...,p_z p1,p2,...,pz为参数的公式,让这个公式返回的结果是 z z z采样的结果呢?
一般来说 π i \pi_i πi是通过神经网络预测对于类别 i i i的概率,这在分类问题中非常常见,假设我们将一个样本送入模型,最后输出的概率分布为 [ 0.2 , 0.4 , 0.1 , 0.2 , 0.1 ] [0.2, 0.4,0.1,0.2,0.1] [0.2,0.4,0.1,0.2,0.1],表明这是一个5分类问题,其中概率最大的是第2类,到这一步,我们直接通过argmax就能获得结果了,但现在我们不是预测问题,而是一个采样问题。对于模型来说,直接取出概率最大的就可以了,但对我们来说,每个类别都是有一定概率的,我们想根据这个概率来进行采样,而不是直接简单无脑的输出概率最大的值
最常见的采样 z \mathbf{z} z的onehot公式为
z = onehot ( max { i ∣ π 1 + π 2 + ⋯ + π i − 1 ≤ u } ) (1) \mathbf{z} = \text{onehot}(\max \{i\mid \pi_1 + \pi_2+\cdots +\pi_{i-1} \leq u\})\tag{1} z=onehot(max{i∣π1+π2+⋯+πi−1≤u})(1)
其中 i = 1 , 2 , . . , x i=1,2,..,x i=1,2,..,x是类别的下标,随机变量 u u u服从均匀分布 U ( 0 , 1 ) U(0,1) U(0,1)
上面这个过程实际上是很巧妙的,我们将概率分布从前往后不断加起来,当加到 π i \pi_i πi时超过了某个随机值$ 0\leq u \leq 1 , 那 么 这 一 次 随 机 采 样 过 程 , ,那么这一次随机采样过程, ,那么这一次随机采样过程,z 就 被 随 机 采 样 为 第 就被随机采样为第 就被随机采样为第i$类,最后通过一个onehot变换
但是上述公式存在一个致命的问题:max函数是不可导的
Gumbel-Max技巧就是解决max函数不可导问题的,我们可以用argmax替换max,即
z = onehot ( argmax i { g i + log π i } ) (2) \mathbf{z} = \text{onehot}(\mathop{\text{argmax}}\limits_{i} \{g_i + \log \pi_i\})\tag{2} z=onehot(iargmax{gi+logπi})(2)
其中, g i = − log ( − log ( u i ) ) , u i ∼ U ( 0 , 1 ) g_i=-\log(-\log(u_i)), u_i \sim U(0,1) gi=−log(−log(ui)),ui∼U(0,1),这一项名为Gumbel噪声,或者叫Gumbel分布,目的是使得 z \mathbf{z} z的返回结果不固定
可以看到式 ( 2 ) (2) (2)的整个过程中,不可导的部分只有argmax,实际上我们可以用可导的softmax函数,在参数 τ \tau τ的控制下逼近argmax,最终 z i z_i zi的公式为
z i = exp ( g i + log π i τ ) ∑ j x exp ( g j + log π j τ ) (3) z_i = \frac{\exp(\frac{g_i + \log \pi_i}{\tau})}{\sum_{j}^x\exp(\frac{g_j + \log \pi_j}{\tau})}\tag{3} zi=∑jxexp(τgj+logπj)exp(τgi+logπi)(3)
其中, τ \tau τ越小 ( τ → 0 ) (\tau \to 0) (τ→0),整个softmax越光滑逼近argmax,并且 z = { z i ∣ i = 1 , 2 , . . . , x } \mathbf{z} = \{z_i\mid i=1,2,...,x\} z={zi∣i=1,2,...,x}也越接近onehot向量; τ \tau τ越大 ( τ → ∞ ) (\tau \to \infty) (τ→∞), z \mathbf{z} z向量越接近于均匀分布
整个过程相当于我们把不可导的取样过程,从 z \mathbf{z} z本身转移到了求 z \mathbf{z} z的公式中的一项 g i g_i gi中,而 g i g_i gi本身不依赖 p 1 , . . , p x p_1,..,p_x p1,..,px,所以 z z z对 p 1 , . . . , p x p_1,...,p_x p1,...,px就可以到了,而且我们得到的 z \mathbf{z} z仍然是离散概率分布的采样。这种采样过程转嫁的技巧有一个专有名词,叫重参数化技巧(Reparameterization Trick)