requires_grad&volatile在冻结训练&预训练中的使用

每个变量都有两个标志:requires_gradvolatile。它们都允许从梯度计算中精细地排除子图,并可以提高效率。

requires_grad

       如果有一个单一的输入操作需要梯度,它的输出也需要梯度。相反,只有所有输入都不需要梯度,输出才不需要。如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行。

>>> x = Variable(torch.randn(5, 5))
>>> y = Variable(torch.randn(5, 5))
>>> z = Variable(torch.randn(5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True

       这个标志特别有用,当您想要冻结部分模型时,或者您事先知道不会使用某些参数的梯度。例如,如果要对预先训练的CNN进行优化,只要切换冻结模型中的requires_grad标志就足够了,直到计算到最后一层才会保存中间缓冲区,其中的仿射变换将使用需要梯度的权重并且网络的输出也将需要它们。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)

# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

volatile

纯粹的inference模式下推荐使用volatile,当你确定你甚至不会调用.backward()时。它比任何其他自动求导的设置更有效——它将使用绝对最小的内存来评估模型。volatile也决定了require_grad is False

volatile不同于require_grad的传递。如果一个操作甚至只有有一个volatile的输入,它的输出也将是volatileVolatility比“不需要梯度”更容易传递——只需要一个volatile的输入即可得到一个volatile的输出,相对的,需要所有的输入“不需要梯度”才能得到不需要梯度的输出。使用volatile标志,您不需要更改模型参数的任何设置来用于inference。创建一个volatile的输入就够了,这将保证不会保存中间状态。

>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad
True
>>> model(volatile_input).requires_grad
False
>>> model(volatile_input).volatile
True
>>> model(volatile_input).creator is None
True

pytorch中的 requires_grad和volatile - 牧马人夏峥 - 博客园

简单总结其用途

(1)requires_grad=Fasle时不需要更新梯度, 适用于冻结某些层的梯度;

 (2)volatile=True相当于requires_grad=False,适用于推断阶段,不需要反向传播。这个现在已经取消了,使用with torch.no_grad()来替代

pytorch学习笔记(三):自动求导_u012436149的博客-CSDN博客

pytorchBP过程是由一个函数决定的,loss.backward(), 可以看到backward()函数里并没有传要求谁的梯度。那么我们可以大胆猜测,在BP的过程中,pytorch是将所有影响lossTensor都求了一次梯度。**但是有时候,我们并不想求所有Tensor的梯度。**那就要考虑如何在Backward过程中排除子图(ie.排除没必要的梯度计算)。
如何BP过程中排除子图? 这就用到了Tensor中的一个参数requires_grad

为什么要排除子图

也许有人会问,梯度全部计算,不更新的话不就得了。
这样就涉及了效率的问题了,计算很多没用的梯度是浪费了很多资源的(时间,计算机内存)

你可能感兴趣的:(深度学习理论,深度学习,神经网络)