学习笔记【gumbel softmax】

gumbel softmax

  • 用于处理argmax不可导的情况

  • 解决思路:引入gumbel分布。在前向传播中使用argmax,后向梯度回传中使用gumbel_softmax计算

  • 代码

def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor:
    。。。
    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

gumbel_softmax中引入了温度t, t越小,softmax就越接近One-hot。为了训练稳定性,一般t会取一个比较大的数字,然后逐步缩小。
内容转载自gumbel softmax

你可能感兴趣的:(学习笔记,学习,深度学习)