【Pytorch】model.eval() vs torch.no_grad()

Date: 2020/10/12

Coder: CW

Foreword:

在将模型用于推断时,我们通常都会将 model.eval() 和 torch.no_grad() 一起使用,那么这两者可否单独使用?它们的区别又在哪里?你清楚吗?之前 CW 对这块也有些迷糊,于是自己做了实验,最终得到了结论,实验过程就不在本文叙述了,仅抛出结论作为参考。


model.eval()

Pytorch 的模型都继承自 torch.nn.Module,这个类有个training 属性,而这个方法会将这个属性值设置为False,从而影响一些模型在前向反馈过程中的操作,比如 BN 和 Dropout 层。在这种情况下,BN层不会统计每个批次数据的均值和方差,而是直接使用在基于训练时得到的均值和方差;Dropout层则会让所有的激活单元都通过。

同时,梯度的计算不受影响,计算流依然会存储和计算梯度,反向传播后仍然能够更新模型的对应的权重(比如BN层的weight和bias依然能够被更新)。


torch.no_grad()

通常是通过上下文的形式使用:

with torch.no_grad():

    your evaluation code

这种情况将停止autograd模块的工作,即不会自动计算和存储梯度,因此能够起到加速计算过程和节省显存的作用,同时也说明了不能够进行反向传播以更新模型权重


Summary

由上可知,在推断时将 model.eval() 与 torch.no_grad() 搭配使用,主要是出于以下几点考虑:

i). 模型中使用了诸如 BN 和 Dropout 这样的网络层,需要使用 model.eval() 来改变它们在前向过程中的操作;

ii). 为了加速计算过程和节省显存,使用torch.no_grad()

你可能感兴趣的:(【Pytorch】model.eval() vs torch.no_grad())