点击关注我哦
神经网络与反向传播
从数学角度上来说,神经网络就是经过训练得到所需结果的一个复杂的数学函数。反向传播是神经网络的重要概念,主要根据链式法则计算损失Loss对输入权重w的梯度(偏导数),然后使用学习率更新权重值,以总体上减少损失。
创建和训练神经网络一般包含以下5个步骤:
1. 定义网络结构;
2. 使用输入数据在该网络结构上进行正向传播;
3. 计算损失Loss;
4. 反向传播计算每个权重的梯度;
5. 根据学习率更新权重值;
整个训练过程是通过迭代的方式完成的,对于每次迭代过程,都会计算几个梯度,并建立一个计算图来存储这些梯度函数。(PyTorch构建动态计算图来实现,TensorFlow构建静态计算图实现)。不同于静态图需要先定义计算图再进行使用,动态图是在每次迭代中从头开始构建的,为梯度计算提供了较大的灵活性。
动态计算图
PyTorch的计算图中,其实只包含了两种元素:张量(变量对象)和运算(函数操作)。其中张量可以分为叶子节点和非叶子节点;运算也就是我们常见的加减乘除、三角函数等可求导的运算。下图是一个将两个张量x和y进行相乘的计算图图示:
图中绿色框表示变量对象,紫色框表示运算操作
上节我们有提到每个变量对象都有8个属性,此处根据图中内容只解释其中常用的5个属性:
data
保存变量数据的属性。例如:上图中变量x持有一个1x1张量,其值为1.0,而变量y持有同样持有一个1x1张量,其值为2.0。z为两个变量的乘积,即2.0。
require_grad
如果需要被计算梯度,则设置该属性为True;否则该属性为False。对于任意Tensor a可以按如下方式进行修改:a.requires_grad_(True)。
grad
grad用来保存梯度值,如果require_grad为False,则将保留None值。在第一次调用.backward()方法后,该属性将获取一个属性值。例如,如果对含有x的变量out调用out.backward(),则x.grad将保存∂out/∂x。并且该属性值将在每次调用.backward()方法时进行累加,因此在训练网络时,每次计算backward()之前需要将前一时刻梯度清零的原因。
注:Tensor的grad属性为None和require_grad为False并不等价,也就是Tensor的require_grad可以为True,但是grad同样可以为None。
grad_fn
grad_fn用来记录反向传播的梯度函数为何种类型,一般叶子节点为None,结果节点的grad_fn才有效。
is_leaf
该节点为叶子节点时is_leaf属性为True,反之则为False。
一般满足以下条件的节点均为叶子节点:
1. 用户自己创建的变量,比如x = torch.tensor(1.0)或x = torch.randn(1,1);
2. 进行运算的所有的张量的require_grad 属性均为False;
3. 某个张量使用.detach()方法将一个非叶子节点剥离成叶子节点。
注:在调用backward()方法时,仅对require_grad和is_leaf属性均为True的节点进行计算。
当设置require_grad = True时, PyTorch将开始进行梯度跟踪并在每个步骤中进行,如下图所示:
在PyTorch中,可以使用下述代码生成上述计算图:
import torch
# 创建计算图
x = torch.tensor(1.0, requires_grad = True)
y = torch.tensor(2.0)
z = x * y
# 输出各变量的属性值
for i, name in zip([x, y, z], "xyz"):
print(f"{name}\ndata: {i.data}\nrequires_grad: {i.requires_grad}\n\
grad: {i.grad}\ngrad_fn: {i.grad_fn}\nis_leaf: {i.is_leaf}\n")
若不想让PyTorch进行梯度跟踪及形成计算图时,可以使用with torch.no_grad()上下文管理器,代码的运行速度也会变快。
import torch
# 创建计算图
x = torch.tensor(1.0, requires_grad = True)
print(x.requires_grad)
y = x * 2
print(y.requires_grad)
with torch.no_grad():
y = x * 2
print(y.requires_grad)
上述代码的输出结果为:
True
True
False
· END ·
RECOMMEND
推荐阅读
1. 效率提升的软件大礼包
2. 深度学习——入门PyTorch(一)
3. 深度学习——入门PyTorch(二)
4. PyTorch入门——autograd(一)