自定义操作torch.autograd.Function

pytorch可以自动求导,但是有时候一些操作是不可导的,这时候你需要自定义求导方式。也就是所谓的 “Extending torch.autograd”。

Function与Module的差异与应用场景

Function与Module都可以对pytorch进行自定义拓展,使其满足网络的需求,但这两者还是有十分重要的不同:

  1. Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数、pooling等操作;Module是保存了参数,因此适合于定义一层,如线性层,卷积层,也适用于定义一个网络
  2. Function需要定义三个方法:init, forward,
    backward(需要自己写求导公式);Module:只需定义__init__和forward,而backward的计算由自动求导机制构成

可以不严谨的认为,Module是由一系列Function组成,因此其在forward的过程中,Function和Variable组成了计算图,在backward时,只需调用Function的backward就得到结果,因此Module不需要再定义backward。
Module不仅包括了Function,还包括了对应的参数,以及其他函数与变量,这是Function所不具备的

假设你现在想自定义一个操作(一个类,假设名字叫LinearFunction),那么就按顺序做下面几件事就好:

1. 首先你要让它继承这个class:torch.autograd.Function。

from torch.autograd import Function
class LinearFunction(Function):

2. 同时,实现这2个函数:

forward():执行这个操作的代码。需要定义LinearFunction这个运算的forward计算过程,同时可以保存任何在后向传播中需要使用的变量值。输出类型是Tensor,或者是Tensor组成的tuple。假设LinearFunction这个运算的输入有 N I N_I NI个,输出有 N O N_O NO个,则forward()的输入有 N I N_I NI 个,输出有 N O N_O NO个。

1 Performs the operation.
2 This function is to be overridden(覆写) by all subclasses.
3 It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
4 The context can be used to store tensors that can be then retrieved during the backward pass.

backward():计算导数的代码。假设LinearFunction这个运算的输入有 N I N_I NI 个,输出有 N O N_O NO个,则backward()的输入有 N O N_O NO个,输出有 N I N_I NI个。代表输出对这 N I N_I NI 个输入的导数。needs_input_grad是一个boolean值组成的元组,代表每个input是否需要求导数。

1 Defines a formula for differentiating the operation.
2 This function is to be overridden by all subclasses.
3 It must accept a context ctx as the first argument, followed by as many outputs did forward() return, and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input.
4 The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computated w.r.t. the output.

注意事项:

forward()函数的输入参数第1个是ctx,第2个是input,其他是可选参数。

save_for_backward(*tensors)

保存给定的张量,以备将来调用backward(),最多调用1次,并且只能从forward()方法内部调用。以后,可以通过saved_tensors属性访问已保存的张量。

mark_dirty(*args)

将给定张量标记为已经in-place operation.。仅应从forward()方法内部调用1次,并且所有参数都应作为输入。

mark_non_differentiable(*args)

将输出标记为不可微分。仅应从forward()方法内部调用1次,并且所有参数都应该是输出。这会将输出标记为不需要梯度,从而提高了backward计算的效率。
  默认情况下,所有可微分类型的输出张量将设置为require gradient,如果您不希望它们要求可微分,则可以使用上面提到的mark_non_differentiable方法。对于非可微类型(例如整数类型)的输出张量,它们不会被标记为requiring gradients。

例1:自定义操作Exp

class Exp(Function):

    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

#Use it by calling the apply method:

output = Exp.apply(input)

注释:
forward()和backward()都应该是staticmethod。

forward()的输入只有2个(ctx, i),ctx必须有,i是input。
ctx.save_for_backward(result)表示forward()的结果要存起来,以后给backward()。

backward()的输入只有2个(ctx, grad_output),ctx必须有,grad_output是最终object对的forward()输出的导数。
result, = ctx.saved_tensors得到之前forward()存的结果。
因为 e x e^x ex的导数还是 e x e^x ex ,所以return grad_output (上一层求导结果)* result( e x e^x ex )。

调用时,直接 Exp.apply(input)即可。

例2:自定义操作LinearFunction

Inherit from Function

class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

grad_input代表object对这个操作的input的导数。
grad_weight代表object对这个操作的weight的导数。
grad_bias代表object对这个操作的bias的导数。
调用时,直接 linear = LinearFunction.apply(input) 即可。

例3:自定义操作MyReLU

class MyReLU(Function):

    @staticmethod
    def forward(ctx, input_):
        # 在forward中,需要定义MyReLU这个运算的forward计算过程
        # 同时可以保存任何在后向传播中需要使用的变量值
        ctx.save_for_backward(input_)         # 将输入保存起来,在backward时使用
        output = input_.clamp(min=0)               # relu就是截断负数,让所有负数等于0
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # 根据BP算法的推导(链式法则),dloss / dx = (dloss / doutput) * (doutput / dx)
        # dloss / doutput就是输入的参数grad_output、
        # 因此只需求relu的导数,在乘以grad_output    
        input_, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0                # 上诉计算的结果就是左式。即ReLU在反向传播中可以看做一个通道选择函数,所有未达到阈值(激活值<0)的单元的梯度都为0
        return grad_input

由非Tensor参数参数化的函数的另一个示例:

 @staticmethod
    def forward(ctx, tensor, constant):
        # ctx is a context object that can be used to stash information
        # for backward computation
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

你可能想要检查实现的向后方法是否实际计算了函数的导数。通过与使用较小有限差的数值近似进行比较,可以实现:

from torch.autograd import gradcheck
 gradcheck takes a tuple of tensors as input, check if your gradient
 evaluated with these tensors are close enough to numerical
 approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)

你可能感兴趣的:(pytorch,深度学习,人工智能)