(个人理解仅供参考)
自己定义的网络结构,没有现成的,就得手写forward和backward
前向传播的表达式
求导结果
前向传播表达式:y = w * x + b
假设f()是我们关于y的loss函数,那么z = f(y)即为loss值
现在要求loss对w、x、b的偏导(假设只有一层):
dz/dx
= dz/dy
* dy/dx
= dz/dy
* w
dz/dw
= dz/dy
* dy/dw
= dz/dy
* x
dz/db
= dz/dy
* dy/db
= dz/dy
* 1
好在dz/dy
不用我们再求了,它就是 backward 的参数grad_output
。那么grad_output
是从哪来的呢?其实就是 forward 会 return output 给 backward ,至于 backward 怎么把 output 变为 grad_output 就不用细究了。
所以:
dz/dx
= grad_output
* w
dz/dw
= grad_output
* x
dz/db
= grad_output
* 1
因此,对于y = w * x + b,我们的代码为:
import torch
from torch.autograd import Function
class MultiplyAdd(Function):
@staticmethod
def forward(ctx, w, x, b):
ctx.save_for_backward(w, x) # 保存参数
output = w * x + b
return output # 传给backward
@staticmethod
def backward(ctx, grad_output):
w, x = ctx.saved_tensors
grad_w = grad_output * x
grad_x = grad_output * w
grad_b = grad_output * 1
return grad_w, grad_x, grad_b # 传给forward
Linear = MultiplyAdd.apply
"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
def forward(self, inputs, parameters):
self.saved_for_backward = [inputs, parameters]
# output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
return output
def backward(self, grad_output):
inputs, parameters = self.saved_tensors # 或self.saved_variables
# grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
return grad_input
"""
验证的话需要使用torch.autograd.gradcheck,给上我的完整代码,验证部分在最后:
import torch
from torch.autograd import Function, gradcheck
"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
def forward(self, inputs, parameters):
self.saved_for_backward = [inputs, parameters]
# output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
return output
def backward(self, grad_output):
inputs, parameters = self.saved_tensors # 或者是self.saved_variables
# grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
return grad_input
"""
class MultiplyAdd(Function):
@staticmethod
def forward(ctx, w, x, b):
ctx.save_for_backward(w, x)
output = w * x + b
return output
@staticmethod
def backward(ctx, grad_output):
w, x = ctx.saved_tensors
grad_w = grad_output * x
grad_x = grad_output * w
grad_b = grad_output * 1
return grad_w, grad_x, grad_b
Linear = MultiplyAdd.apply
x = torch.ones(1, requires_grad=True, dtype=torch.float64)
w = torch.rand(1, requires_grad=True, dtype=torch.float64)
b = torch.rand(1, requires_grad=True, dtype=torch.float64)
# print("start forward...")
# z = MultiplyAdd.apply(w, x, b)
# print("start backward...")
# z.backward()
#
# print(x.grad, w.grad, b.grad)
test = gradcheck(Linear, (x, w, b), eps=1e-6)
print(test)
现在我只是会用了这个,但是如果是两层的全连接层,这段代码是怎么工作的?这个问题我还没想明白,留个坑