Pytorch 可微分round函数

round函数在定义域中的导数,处处为0或者无穷,梯度无法反向传播。本文将使用autograd.function类自定义可微分的round函数,使得round前后的tensor,具有相同的梯度。

from torch.autograd import Function


class BypassRound(Function):
  @staticmethod
  def forward(ctx, inputs):
    return torch.round(inputs)

  @staticmethod
  def backward(ctx, grad_output):
    # 这里的grad_output是round之后的tensor的梯度,直接将它作为round之前tensor的梯度
    return grad_output


# Function.apply的别名
bypass_round = BypassRound.apply

# demo
z3_rounded = bypass_round(z3)

 具体原理和细节参考以下博客:

定义torch.autograd.Function的子类,自己定义某些操作,且定义反向求导函数_tsq292978891的博客-CSDN博客_saved_tensors

2022.4.7更新:更简单的方法如下

def ste_round(x):
    return torch.round(x) - x.detach() + x

torch.round(x)导数处处为0,x.detach()在计算图中,x的导数为1

因此:ste_round(x)的梯度 == x的梯度

你可能感兴趣的:(科学炼丹,pytorch,深度学习)