前言:前面的一篇文章中,已经很详细的说清楚了nn.Module、nn.functional、autograd.Function三者之间的联系和区别,虽然autograd.Function本质上是自定义函数的,但是由于神经网络、层、激活函数、损失函数本质上都是函数或者是多个函数的组合,所以使用autograd.Function依然可以达到定义层、激活函数、损失函数、甚至模型的目的,就像我们使用nn.Module是一样的,只不过更偏底层,稍微复杂一些而已,因为需要自己定义求导函数。
但是需要特别注意的是,对于复杂的层或者是网络,使用autograd.Function几乎是不可行的,因为我们需要重新定义反向求导规则即backward函数,而复杂层或者网络没办法写出每一个参数的导函数,或者是即便写出来也是异常复杂(因为链式求导法则再加上一些非线性函数的关系)所以一般不推荐使用autograd.Function去定义层,更不要去定义模型,但是一般定义一个较简单的函数还是可以的。
所以本文依然只会涉及到简单的定义操作,旨在帮助更好地理解autograd.Function的工作过程。关于autograd.Function的详细定义过程,我们可以参考前一篇文章:
pytorch的自定义拓展之(一)——torch.nn.Module和torch.autograd.Function
前面的文章都是使用的实例方法来重写forward和backward方法,下面来看一下如果使用静态方法,像Function类定义的那样,怎么实现。
1.1 重写Function类的静态方法
import torch
from torch.autograd import Function
# 类需要继承Function类,此处forward和backward都是静态方法
class MultiplyAdd(Function):
@staticmethod
def forward(ctx, w, x, b):
ctx.save_for_backward(w,x) #保存参数,这跟前一篇的self.save_for_backward()是一样的
output = w * x + b
return output
@staticmethod
def backward(ctx, grad_output): #获取保存的参数,这跟前一篇的self.saved_variables()是一样的
w,x = ctx.saved_variables
print("=======================================")
grad_w = grad_output * x
grad_x = grad_output * w
grad_b = grad_output * 1
return grad_w, grad_x, grad_b # backward输入参数和forward输出参数必须一一对应
x = torch.ones(1,requires_grad=True) # x 是1,所以grad_w=1
w = torch.rand(1,requires_grad=True) # w 是随机的,所以grad_x=随机的一个数
b = torch.rand(1,requires_grad=True) # grad_b 恒等于1
print('开始前向传播')
z=MultiplyAdd.apply(w, x, b) # forward,这里的前向传播是不一样的,这里没有使用函数去包装自定义的类,而是直接使用apply方法
print('开始反向传播')
z.backward() # backward
print(x.grad, w.grad, b.grad)
'''运行结果为:
开始前向传播
开始反向传播
=======================================
tensor([0.1784]) tensor([1.]) tensor([1.])
'''
注意:上面最大的不同除了使用的是静态方法以外,最大的不同在于,我没有使用一个函数去包装我的自定义类,而是直接使用了 z=MultiplyAdd.apply(w, x, b) 去完成前向运算过程,
这个apply方法是定义在Function类的父类_FunctionBase中定义的一个方法,但是这个方法到底是怎么实现的还不得而知。
这里到底是为什么,我还没有搞得特别清楚,因为我没有找到关于apply的详细代码所在,如果有大佬知道,望告知,万分感谢!
前面已经多次强调,虽然Function类可以用来定义模型,但是不要这么去做,因为Function类本身是为自定义函数而存在,我们在这里演示一下如何使用Function类自定义一个层,然后使用这个自定义的层来搭建网络。
为了简单,本文以搭建一个线性层作为实例说明:
参考下面的文章:
定义torch.autograd.Function的子类,自己定义某些操作,且定义反向求导函数