梯度是如何在Pytorch中传递的

在Pytorch中,传入网络中计算的数据类型必须是Tensor类型,如果requires_grad = True的话,就会保存着梯度和创建这个Tensor的function的引用,换句话说,就是记录网络每层的梯度和网络图,可以实现梯度的反向传播,网络图可以表示如下(来自Deep Learning with PyTorch: A 60 Minute Blitz):

i n p u t → c o n v 2 d → r e l u → m a x p o o l 2 d → c o n v 2 d → r e l u → m a x p o o l 2 d → input \rightarrow conv2d \rightarrow relu \rightarrow maxpool2d \rightarrow conv2d \rightarrow relu \rightarrow maxpool2d \rightarrow inputconv2drelumaxpool2dconv2drelumaxpool2d v i e w → l i n e a r → r e l u → l i n e a r → r e l u → l i n e a r → M S E L o s s → l o s s view \rightarrow linear \rightarrow relu \rightarrow linear \rightarrow relu \rightarrow linear \rightarrow MSELoss \rightarrow loss viewlinearrelulinearrelulinearMSELossloss

则根据最后得到的loss可以逐步递归的求其每层的梯度,并实现权重更新。

在实现梯度反向传递时主要需要三步:

  • 初始化梯度值: net.zero_grad()
  • 反向求解梯度: loss.backward()
  • 更新参数:optimizer.step()

注意:对于一个输入input,经过网络计算得到output,在计算梯度就是output=>input的递归过程,在递归完图后会释放图的缓存,因此在第二次使用output进行梯度计算时会出现错,如下;

RuntimeError: Trying to backward through the graph second time, but the buffers have already been freed. 
Please specify retain_variables=True when calling backward for the first time.

下面我们拿GAN举一下例子。因为GAN由两个网络组成,其参数的更新也涉及到两个网络的交替。
简单介绍一下如下:
Generator生成新的数据,Discriminator用于判断数据是真实的还是生成的。通过训练Discriminator使得Discriminator可以准确的分辨数据的真伪,通过训练Generator使得Discriminator无法分辨真伪。

在这里插入图片描述
训练时先更新Discriminator,然后再更新Generator:
在这里插入图片描述
看上图中GAN更新方式,在1中需要使用真实数据和生成数据来更新Discriminator,但是生成数据由Generator得到,再传入到Discriminator中进行计算,因此进行梯度反向计算时会同时计算出Generator的梯度,并释放网络递归图的缓存,则在2中更新Generator时会报错。

在1中计算Discriminator梯度时不需要计算Generator的梯度,因此在使用生成数据计算Discriminator时使用gendata.detach()作为输入数据,这样就对当前图进行拆分,从而得到一个新的Tensor变量。

下面就用一段代码来验证一下这些功能,这段代码本身倒没有什么特别的功能。

构建两个网络A和B,首先使用A网络的结果计算B网络,然后更新B网络,最后更新A网络,这种情况与GAN相似,B网络需要使用A网络的结果进行计算,如果更新B网络,则连带着A网络的梯度也会计算,当最后更新B网络时,则会报错。

具体代码如下:

import torch
import torch.nn as nn
import torch.optim as optim

class A(nn.Module):
	def __init__(self):
		super(A, self).__init__()
		

Reference:
https://blog.csdn.net/u011276025/article/details/76997425

你可能感兴趣的:(Pytorch)