torch.autograd.Function 自定义求导/反向传播方式

虽然pytorch可以自动求导,但是有时候一些操作是不可导的,这时候你需要自定义求导方式。也就是所谓的 "Extending torch.autograd"。

如果想要通过Function自定义一个操作,需要

①继承torch.autograd.Function这个类

from torch.autograd import Function
class LinearFunction(Function):

②实现forward()和backward()

属性(成员变量)
saved_tensors: 传给forward()的参数,在backward()中会用到。
needs_input_grad:长度为 :attr:num_inputs的bool元组,表示输出是否需要梯度。可以用于优化反向过程的缓存。
num_inputs: 传给函数 :func:forward的参数的数量。
num_outputs: 函数 :func:forward返回的值的数目。
requires_grad: 布尔值,表示函数 :func:backward 是否永远不会被调用。

成员函数
forward()
forward()可以有任意多个输入、任意多个输出,但是输入和输出必须是Variable。(官方给的例子中有只传入tensor作为参数的例子)
backward()
backward()的输入和输出的个数就是forward()函数的输出和输入的个数。其中,backward()输入表示关于forward()输出的梯度(计算图中上一节点的梯度),backward()的输出表示关于forward()的输入的梯度。在输入不需要梯度时(通过查看needs_input_grad参数)或者不可导时,可以返回None。

ctx is a context object that can be used to stash information for backward computation

ctx可以利用save_for_backward来保存tensors,在backward阶段可以进行获取

例1

torch.autograd.Function 自定义求导/反向传播方式_第1张图片z

import torch
from torch import nn
from torch.autograd import Function
import torch

class Exp(Function):
    
    @staticmethod
    def forward(ctx, input):
        result = torch.exp(input)
        ctx.save_for_backward(result)
        return result
    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

x = torch.rand(4,3,5,5)
exp = Exp.apply # Use it by calling the apply method:
output = exp(x)
print(output.shape)

自定义的forward和backward要用静态方法,网上也有别的人写成def forward(self, input_):这种形式,但是这种写法快要被Pytorch淘汰了

例2

import torch
from torch import nn
from torch.autograd import Function
import torch

class MyReLU(Function):

    @staticmethod
    def forward(ctx, input_):
        # 在forward中,需要定义MyReLU这个运算的forward计算过程
        # 同时可以保存任何在后向传播中需要使用的变量值
        ctx.save_for_backward(input_)         # 将输入保存起来,在backward时使用
        output = input_.clamp(min=0)               # relu就是截断负数,让所有负数等于0
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # 根据BP算法的推导(链式法则),dloss / dx = (dloss / doutput) * (doutput / dx)
        # dloss / doutput就是输入的参数grad_output、
        # 因此只需求relu的导数,在乘以grad_output    
        input_, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_ < 0] = 0                # 上诉计算的结果就是左式。即ReLU在反向传播中可以看做一个通道选择函数,所有未达到阈值(激活值<0)的单元的梯度都为0
        return grad_input

x = torch.rand(4,3,5,5)
myrelu = MyReLU.apply # Use it by calling the apply method:
output = myrelu(x)
print(output.shape)

例3

torch.autograd.Function 自定义求导/反向传播方式_第2张图片

import torch
from torch.autograd import Function
from torch.autograd import gradcheck

class LinearFunction(Function):
    # 创建torch.autograd.Function类的一个子类
    # 必须是staticmethod
    @staticmethod
    # 第一个是ctx,第二个是input,其他是可选参数。
    # ctx在这里类似self,ctx的属性可以在backward中调用。
    # 自己定义的Function中的forward()方法,所有的Variable参数将会转成tensor!因此这里的input也是tensor.在传入forward前,autograd engine会自动将Variable unpack成Tensor。
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias) # 将Tensor转变为Variable保存到ctx中
        output = input @ weight.t()
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output) #unsqueeze(0) 扩展处第0维
            # expand_as(tensor)等价于expand(tensor.size()), 将原tensor按照新的size进行扩展
        return output

    @staticmethod
    def backward(ctx, grad_output): 
        # grad_output为反向传播上一级计算得到的梯度值
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # 分别代表输入,权值,偏置三者的梯度
        # 判断三者对应的Variable是否需要进行反向求导计算梯度
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight # 复合函数求导,链式法则
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t() @ input #复合函数求导,链式法则
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias

linear = LinearFunction.apply

# gradchek takes a tuple of tensor as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = torch.randn(20,20,requires_grad=True).double()
weight = torch.randn(20,20,requires_grad=True).double()
bias = torch.randn(20,requires_grad=True).double()
test = gradcheck(LinearFunction.apply, (input,weight,bias), eps=1e-6, atol=1e-4)
print(test)  # 没问题的话输出True

ctx.needs_input_grad作为一个boolean型的表示也可以用来控制每一个input是否需要计算梯度,e.g., ctx.needs_input_grad[0] = False,表示forward里的第一个input不需要梯度,若此时我们return这个位置的梯度值的话,为None即可

Function与Module的差异与应用场景

Function与Module都可以对pytorch进行自定义拓展,使其满足网络的需求,但这两者还是有十分重要的不同:

  • Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数、pooling等操作;Module是保存了参数,因此适合于定义一层,如线性层,卷积层,也适用于定义一个网络
  • Function需要定义三个方法:__init__, forward, backward(需要自己写求导公式);Module:只需定义__init__和forward,而backward的计算由自动求导机制构成

你可能感兴趣的:(Pytorch)