Gumbel Softmax数学证明及其应用解析

Gumbel Softmax数学证明及其应用解析

  • 论文地址:
  • 问题描述:
  • 具体实现:
  • 数学证明:
  • 代码实现:

论文地址:

https://arxiv.org/pdf/1611.01144.pdf%20http://arxiv.org/abs/1611.01144.pdf

问题描述:

考虑离散变量 x x x ,如果已知其分布向量 π = { π 1 , … , π N } \pi=\{\pi_1,\dots,\pi_N\} π={π1,,πN} ,则得到 x x x 的采样的一个简单采样方法是:

x π = o n e _ h o t ( arg max ⁡ i ( π i ) ) x_{\pi}=one\_hot (\argmax_i(\pi_i)) xπ=one_hot(iargmax(πi))

根据 s o f t m a x softmax softmax 函数,其中 arg ⁡ max ⁡ \arg \max argmax 这一操作取得 π n \pi_n πn 的概率为:

P ( π n ) = e π n ∑ j = 1 N e π j P(\pi_n)=\frac{e^{\pi_{n} }}{\sum_{j=1}^{N} e^{\pi_{j} }} P(πn)=j=1Neπjeπn

当我们在BP神经网络中,需要让采样结果可导的时候,这样简单的采样就行不通了。

原因在于 arg ⁡ max ⁡ \arg \max argmax 这一操作是不可导的,即并没有一个表达式可以映射 π \pi π z z z 上。

该技巧应用甚广,如深度学习中的各种 GAN、强化学习中的 A2C 和 MADDPG 算法等等。

只要涉及在离散分布上运用重参数技巧时(re-parameterization),都可以试试 Gumbel-Softmax Trick。

具体实现:

一般来说,对于 N N N 维概率向量 π \pi π,我们可以通过添加随机 Gumbel 噪声 G i G_i Gi 再取样:

x π = arg max ⁡ i ( ln ⁡ ( π i ) + G i ) x_{\pi}=\argmax_i \left(\ln \left(\pi_{i}\right)+G_{i}\right) xπ=iargmax(ln(πi)+Gi)

其中 G i G_i Gi 是独立同分布的标准 Gumbel 分布的随机变量。

我们重新看一下 Gumbel 分布,Gumbel 分布是一种极值型分布,它的概率密度函数(PDF)为:

f ( x ; μ , β ) = e − z − e − z , z = x − μ β f(x ; \mu, \beta)=e^{-z-e^{-z}}, z=\frac{x-\mu}{\beta} f(x;μ,β)=ezez,z=βxμ

公式中, μ \mu μ 是位置系数, β \beta β 是尺度系数,标准 Gumbel 分布中有: μ = 0 \mu=0 μ=0 β = 1 \beta=1 β=1

Gumbel Softmax数学证明及其应用解析_第1张图片
相应的,Gmubel 分布的累积密度函数(CDF)为:

F ( x ; μ , β ) = e − e − ( x − μ ) / β F(x ; \mu, \beta)=e^{-e^{-(x-\mu) / \beta}} F(x;μ,β)=ee(xμ)/β

并且我们易得它的反函数:

F − 1 ( y ; μ , β ) = μ − β ln ⁡ ( − ln ⁡ ( y ) ) F^{-1}(y ; \mu, \beta)=\mu-\beta \ln (-\ln (y)) F1(y;μ,β)=μβln(ln(y))

这样我们就可以通过从均匀分布中求逆得到 G i G_i Gi

G i = − ln ⁡ ( − ln ⁡ ( U i ) ) , U i ∼ U ( 0 , 1 ) G_{i}=-\ln \left(-\ln \left(U_{i}\right)\right), U_{i} \sim U(0,1) Gi=ln(ln(Ui)),UiU(0,1)

这就是 Gmubel-Max trick。

由于上述算法中 arg ⁡ max ⁡ \arg \max argmax 这一操作仍是不可导的,因此我们可以用两种方式来让该操作可导。

一种方式是使用Straight-Through Estimator 思想(例如 VQ-VAE 中使用的),重新设计采样为:

t π = s o f t m a x ( ln ⁡ ( π i ) + G i ) t_\pi=softmax(\ln \left(\pi_{i}\right)+G_{i}) tπ=softmax(ln(πi)+Gi)

z π = o n e _ h o t ( arg ⁡ max ⁡ ( t π ) ) z_{\pi}=one\_hot(\arg \max \left(t_\pi\right)) zπ=one_hot(argmax(tπ))

x π = t π + s g [ z π − t π ] x_\pi = t_\pi + sg[z_\pi-t_\pi] xπ=tπ+sg[zπtπ]

其中 s g sg sg 是 stop gradient 的意思,这样向前传播的时候使用的是采样的 z π z_{\pi} zπ,而反向传播则是使用 t π t_\pi tπ

但是 Gumbel Softmax 的作者指出并证明,直接用 s o f t m a x softmax softmax 函数来代替量化过程也是可行的,即:

x π = s o f t m a x ( ln ⁡ ( π i ) + G i ) x_{\pi}=softmax \left(\ln \left(\pi_{i}\right)+G_{i}\right) xπ=softmax(ln(πi)+Gi)

具体操作为:

  1. 对于 N N N 维概率向量 π \pi π,我们生成 N N N 个服从均匀分布 U ( 0 , 1 ) U(0,1) U(0,1) 的独立样本 U 1 , … , U N U_1,\dots,U_N U1,,UN
  2. 通过 G i = − ln ⁡ ( − ln ⁡ U i ) G_i=-\ln(-\ln U_i) Gi=ln(lnUi) 计算得到 G i G_i Gi
  3. 对应相加得到新的向量 z z z,其中 z i = π i + G i z_i=\pi_i+G_i zi=πi+Gi
  4. 通过 s o f t m a x softmax softmax 函数计算 x π x_\pi xπ,其中:

x i = e z i / τ ∑ j = 1 N e z j / τ x_i=\frac{e^{z_{i} / \tau}}{\sum_{j=1}^{N} e^{z_{j} / \tau}} xi=j=1Nezj/τezi/τ

前三步的目标是让新的随机变量 z z z 与原随机变量 π \pi π 相同,只需要证明取到 z n z_n zn 的概率跟取到 π n \pi_n πn 的概率相同,第四步则是使用温度参数 τ \tau τ 来控制采样结果的分布倾向:

  1. τ \tau τ 越小时,结果 x π x_\pi xπ 就越接近于 one-hot 分布;
  2. τ \tau τ 越大时,结果 x π x_\pi xπ 就越接近于均匀分布;

下面我们来证明。

数学证明:

证明取到 z n z_n zn 的概率跟取到 π n \pi_n πn 的概率相同可以写为:

P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = P ( π n ) P\left(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N}\right)=P(\pi_n) P(znzn;n=n{πn}n=1N)=P(πn)

也就是 z n z_{n} zn 比其他所有 z n ′ z_{n^{\prime}} zn 都大的概率为 P ( π n ) P(\pi_n) P(πn)

根据条件累积概率分布函数,我们可以得到:

P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∏ n ′ ≠ n P ( z n ≥ z n ′ ) P\left(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N}\right)=\prod_{n^{\prime} \neq n}P(z_n\geq z_n^{\prime}) P(znzn;n=n{πn}n=1N)=n=nP(znzn)

注意到, z n = π n + G n z_n=\pi_n+G_n zn=πn+Gn,并且 G n G_n Gn 服从 μ = 0 \mu=0 μ=0 β = 1 \beta=1 β=1 的标准 Gumbel 分布,那么将 π n \pi_n πn 看作常数时, z n z_n zn 服从 μ = π n \mu=\pi_n μ=πn β = 1 \beta=1 β=1 的标准 Gumbel 分布,它的 CDF 为:

F z n ( x ) = e − e − ( x − π n ) F_{z_n}(x)=e^{-e^{-(x-\pi_n)}} Fzn(x)=ee(xπn)

也就是:

F z n ′ ( x ) = e − e − ( x − π n ′ ) F_{z_n^{\prime}}(x)=e^{-e^{-(x-\pi_{n^{\prime}})}} Fzn(x)=ee(xπn)

那么根据 CDF 的定义,我们可得:

P ( z n ≥ z n ′ ) = P ( z n ′ ≤ z n ) = F z n ′ ( z n ) = e − e − ( z n − π n ′ ) P(z_n\geq z_n^{\prime})=P(z_n^{\prime}\leq z_n)=F_{z_n^{\prime}}(z_n)=e^{-e^{-(z_n-\pi_{n^{\prime}})}} P(znzn)=P(znzn)=Fzn(zn)=ee(znπn)

即:

P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∏ n ′ ≠ n e − e − ( z n − π n ′ ) P\left(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N}\right)=\prod_{n^{\prime} \neq n}e^{-e^{-(z_n-\pi_{n^{\prime}})}} P(znzn;n=n{πn}n=1N)=n=nee(znπn)

同时我们可得 z n z_n zn 分布的 CDF 为:

f z n ( x ) = e − ( x − π n ) − e − ( x − π n ) f_{z_n}(x)=e^{-(x-\pi_n)-e^{-(x-\pi_n)}} fzn(x)=e(xπn)e(xπn)

z n z_n zn 求积分可得边缘累积概率分布函数:

P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∫ P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) ⋅ f z n ( z n ) d z n \begin{aligned} P(z_{n} \geq z_{n^{\prime}} &; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N})\\=& \int P\left(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N}\right) \cdot f_{z_{n}}\left(z_{n} \right) d z_{n} \end{aligned} P(znzn=;n=n{πn}n=1N)P(znzn;n=n{πn}n=1N)fzn(zn)dzn

带入 CDF 可得:

P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∫ ∏ n ′ ≠ n e − e − ( z n − π n ′ ) ⋅ e − ( z n − π n ) − e − ( z n − π n ) d z n \begin{aligned} P(z_{n} \geq z_{n^{\prime}} &; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N})\\=& \int \prod_{n^{\prime} \neq n}e^{-e^{-(z_n-\pi_{n^{\prime}})}} \cdot e^{-(z_n-\pi_n)-e^{-(z_n-\pi_n)}} d z_{n} \end{aligned} P(znzn=;n=n{πn}n=1N)n=nee(znπn)e(znπn)e(znπn)dzn

化简可得:

P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∫ ∏ n ′ ≠ n e − e − ( z n − π n ′ ) ⋅ e − ( z n − π n ) − e − ( z n − π n ) d z n = ∫ e − ∑ n ′ ≠ n e − ( z n − π n ) − ( z n − π n ) − e − ( z n − π n ) d z n = ∫ e − ∑ n ′ = 1 N e − ( z n − π n ′ ) − ( z n − π n ) d z n = ∫ e − ( ∑ n ′ = 1 N e π n ′ ) e − z n − z n + π n d z n = ∫ e − e − z n + ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) − z n + π n d z n = ∫ e − e − ( z n − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) ) − ( z n − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) ) − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) + π n d z n = e − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) + π n ∫ e − e − ( z n − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) ) − ( z n − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) ) d z n = e π n ∑ n ′ = 1 N e π n ′ ∫ e − e − ( z n − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) ) − ( z n − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) ) d z n = e π n ∑ n ′ = 1 N e π n ′ ∫ e − ( z n − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) ) − e − ( z n − ln ⁡ ( ∑ n ′ = 1 N e π n ′ ) ) d z n \begin{array}{l} P(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\{\pi_{n^{\prime}}\}_{n^{\prime}=1}^{N})\\ =\int \prod_{n^{\prime} \neq n} e^{-e^{-(z_{n}-\pi_{n^{\prime}})}} \cdot e^{-(z_{n}-\pi_{n})-e^{-(z_{n}-\pi_{n})}} d z_{n}\\ =\int e^{-\sum_{n^{\prime} \neq n} e^{-(z_{n}-\pi_{n})}-(z_{n}-\pi_{n})-e^{-(z_{n}-\pi_{n})}} d z_{n}\\ =\int e^{-\sum_{n^{\prime}=1}^{N} e^{-(z_{n}-\pi_{n^{\prime}})}-(z_{n}-\pi_{n})} d z_{n}\\ =\int e^{-(\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}}) e^{-z_{n}}-z_{n}+\pi_{n}} d z_{n}\\ =\int e^{-e^{-z_{n}+\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}})}{-z_{n}+\pi_{n}}} d z_{n}\\ =\int e^{-e^{-(z_{n}-\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}}))}-(z_{n}-\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}}))-\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}} )+\pi_{n}} d z_{n}\\ =e^{-\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^\prime}})+\pi_{n}} \int e^{-e^{-(z_{n}-\ln (\sum_{n^{\prime}=1}^{N}e^{\pi_{n^{\prime}}} ))}-(z_{n}-\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}}))} d z_{n}\\ =\frac{e^{\pi_{n}}}{\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}}} \int e^{-e^{-(z_{n}-\ln (\sum_{n^{\prime}=1}^{N}e^{\pi_{n^\prime}}))}-(z_{n}-\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}})) }d z_{n}\\ =\frac{e^{\pi_{n}}}{\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}}} \int e^{-(z_{n}-\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}}))-e^{-(z_{n}-\ln (\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}}}))}} d z_{n} \end{array} P(znzn;n=n{πn}n=1N)=n=nee(znπn)e(znπn)e(znπn)dzn=en=ne(znπn)(znπn)e(znπn)dzn=en=1Ne(znπn)(znπn)dzn=e(n=1Neπn)eznzn+πndzn=eezn+ln(n=1Neπn)zn+πndzn=ee(znln(n=1Neπn))(znln(n=1Neπn))ln(n=1Neπn)+πndzn=eln(n=1Neπn)+πnee(znln(n=1Neπn))(znln(n=1Neπn))dzn=n=1Neπneπnee(znln(n=1Neπn))(znln(n=1Neπn))dzn=n=1Neπneπne(znln(n=1Neπn))e(znln(n=1Neπn))dzn

注意到积分内为符合 μ = ln ⁡ ( ∑ k ′ = 1 K e x k ′ ) \mu=\ln \left(\sum_{k^{\prime}=1}^{K} e^{x_{k^{\prime}}}\right) μ=ln(k=1Kexk) 的 Gumbel 分布,所以积分的结果为 1 1 1,即:

P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = e π n ∑ n ′ = 1 N e π n ′ = P ( π n ) P(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\{\pi_{n^{\prime}}\}_{n^{\prime}=1}^{N})=\frac{e^{\pi_{n} }}{\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}} }}=P(\pi_n) P(znzn;n=n{πn}n=1N)=n=1Neπneπn=P(πn)

代码实现:

在 pytorch 中已经给出其实现:

def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor:

    if has_torch_function_unary(logits):
        return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim)
    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

注意这里生成的 G i G_i Gi (对应代码中 gumbels)并没有采用论文中的方式,而是用了其等价方式,即直接从指数分布中采样(tensor.exponential_()),然后再取负对数。

这是因为 − ln ⁡ U i , U ∼ U ( 0 , 1 ) -\ln U_i,U\sim U(0,1) lnUi,UU(0,1) 的分布符合指数分布,证明如下:

Y = − l n ( X ) Y=-ln(X) Y=ln(X),且 X ∼ U ( 0 , 1 ) X\sim U(0,1) XU(0,1),有:

F y ( Y ) = P ( Y < y ) = P ( − ln ⁡ x < y ) = P ( x > e − y ) = 1 − P ( x ≤ e − y ) = 1 − F x ( e − y ) = 1 − e − y \begin{array}{l} F_y(Y)=P(Ye^{-y})=1-P(x\leq e^{-y})=1-F_x(e^{-y})=1-e^{-y} \end{array} Fy(Y)=P(Y<y)=P(lnx<y)=P(x>ey)=1P(xey)=1Fx(ey)=1ey

对其求导可得,其概率密度函数为:

f Y ( y ) = e − y f_Y(y)=e^{-y} fY(y)=ey

刚好为指数分布。

你可能感兴趣的:(概率论,机器学习,深度学习,人工智能,transformer)