Pytorch学习记录-使用Pytorch进行深度学习,60分钟闪电战02

首页.jpg

Autograd:自动求导

PyTorch中所有神经网络的核心是autograd包。我们首先简要地看看,然后训练第一个神经网络。
Autograd包为Tensors所有操作提供了自动求导。这是一个逐个运行的框架,这意味着用户的BP反向传递是由代码运行方式定义,并且每个迭代都不同。

Tensor

几个重要的类和参数
torch.Tensor是Autograd的核心类,如果设置参数.requires_grad为True,将会开始追踪Tensor的所有操作。当完成计算后,可以使用.backward()并自动获取所有梯度计算。Tensor的梯度将被汇集到.grad参数中。
如果想阻止Tensor追踪历史记录。

  • 可以调用.detach(),将其从历史记录中分离出来,并且能够防止将来的计算被追踪。
  • 可以调用torch.no_grad()来包裹代码,在评估模型时会有用,因为模型可能具有requires_grad = True的可训练参数,但我们不需要渐变。

还有一个重要的类,Function。

Pytorch学习记录-使用Pytorch进行深度学习,60分钟闪电战02_第1张图片
autograd

可以通过属性** .data 来访问原始的tensor,而关于这一Variable的梯度则集中于 .grad **属性中。
Tensor和Function相互连接并构建一个无环图,这个图能够编码一个完整的计算历史。每一个Tensor都有一个.grad_fn参数,这个参数可以引用已创建的Tensor中的Function(除非用户创建的 Tensor中参数.grad_fn为None)。

如果如果要计算导数,可以在Tensor上调用.backward()。如果Tensor是标量,(即它包含一个元素数据),则不需要为backward()指定任何参数,但是如果它有更多元素,则需要指定一个渐变参数,该参数是匹配形状的张量。

下面来试试grad_fn

#创建一个tensor,设定requires_grad=True来追踪计算
x=torch.ones(2,2,requires_grad=True)
print(x)

#进行tensor操作
y=x+2
print(y)
#查看y的grad_fn
print(y.grad_fn)
#完成更多的操作
z=y*y*3
out=z.mean()
print(z, z.grad_fn,out)

.requires_grad_(...) 能够更改现有的Tensor的requires_grad标志。如果没有给出,输入标志默认为False。

a = torch.randn(2, 2)
a = ((a * 3) / (a - 1))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a * a).sum()
print(b.grad_fn)

梯度

开始BP反向传递,因为out包括了一个单独scalar,out.backward()和out.backward(torch.tensor(1.))相同。

out.forward()
print(x.grad)

你应该有一个含有4.5的矩阵。我们称之为“Tensor”。
在数学上,如果你有一个向量值函数,那么相对于的渐变是一个Jacobian matrix:一般来说,torch.autograd是一个Jacobian matrix的引擎。也就是说,给定任何向量,计算结果。如果恰好是标量函数的梯度,也就是说,那么通过链规则,Jacobian matrix乘积将是相对于的梯度:(注意,给出一个可以被视为列向量的行向量通过服用。)
上面这段在原教程中有点问题,总之就是pytorch包含有一个自动求导的核心功能,在Jacobian matrix同样适用。

现在让我们看一下Jacobian matrix的例子

x=torch.randn(3,requires_grad=True)
y=x*2
while y.data.norm()<1000:
    y=y*2

print(y)

v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)

print(x.grad)

print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

你可能感兴趣的:(Pytorch学习记录-使用Pytorch进行深度学习,60分钟闪电战02)