本文阅读时长大约5分钟。
机器学习及深度学习中,梯度求导这个操作无处不在,pytorch为了使用方便,将梯度求导的方法包含在torch.autograd类中。
在一般的梯度求解中,使用其中的 torch.autograd.grad(y,x,retain_graph),表示求y对x的偏导,具体代码如下:
x = torch.tensor([3.], requires_grad=True)
y = torch.pow(x, 2) # y = x**2
grad_1 = torch.autograd.grad(y, x, create_graph=True) # grad_1 = dy/dx = 2x = 2 * 3 = 6
print(grad_1)
grad_2 = torch.autograd.grad(grad_1[0], x) # grad_2 = d(dy/dx)/dx = d(2x)/dx = 2
print(grad_2)
结果即为:
(tensor([6.], grad_fn=<MulBackward0>),)
(tensor([2.]),)
特殊的,为了适用于模型训练过程中的反向传播算法,torch.autograd.backward(variables, gradient, retain_graph) 经常会被使用。其中variables是被求积分的叶子节点;gradient是对应variable的梯度,仅当variable不是标量且需要求梯度的时候使用;retain_graph是否保存计算图。
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x) # retain_grad()
b = torch.add(w, 1)
y0 = torch.mul(a, b) # y0 = (x+w) * (w+1) dy2/dw = 5
y1 = torch.add(a, b) # y1 = (x+w) + (w+1) dy1/dw = 2
loss = torch.cat([y0, y1], dim=0) # [y0, y1]
grad_tensors = torch.tensor([1., 2.]) # 求梯度时是 5*1 + 2*2 = 9,对w求导
loss.backward(gradient=grad_tensors) # gradient 传入 torch.autograd.backward()中的grad_tensors
print(w.grad)
输出结果是:
9
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward(retain_graph=True)
print(w.grad)
y.backward()
结果为:
tensor([5.])
若中间的y.backward(retain_graph=False),则会报错
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
这一点对于pytorch入门来说是必须要熟悉的,pytorch为了获得更大的样本训练容量,梯度是会累加的,如下代码所示:
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
for i in range(3):
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward()
print(w.grad)
显示结果如下,
tensor([5.])
tensor([10.])
tensor([15.])
所求得的梯度会累加,在上面代码for循环中补上最后一句,w.grad.zero_(),则每次代码都清零,显示结果为:
tensor([5.])
tensor([5.])
tensor([5.])
即回归正常。关于pytorch的梯度采用累加操作,每次迭代一个循环后清零,简单来讲是pytorch特有的,有扩大batchsize的好处及减小内存消耗的好处[2],具体见参考链接。
参考自: