torch中的model.eval()、model.train()详解

个人简介: 深度学习图像领域工作者
工作总结链接:https://blog.csdn.net/qq_28949847/article/details/128552785
             链接中主要是个人工作的总结,每个链接都是一些常用demo,代码直接复制运行即可。包括:
                    1.工作中常用深度学习脚本
                    2.torch、numpy等常用函数详解
                    3.opencv 图片、视频等操作
                    4.个人工作中的项目总结(纯干活)
视频讲解: 以上记录,通过B站等平台进行了视频讲解使用,可搜索 ‘Python图像识别’ 进行观看
              B站:Python图像识别
              抖音:Python图像识别
              西瓜视频:Python图像识别


记录此博客原因:
在进行车牌识别时,发现trian的模型loss很低,效果很好,但是当预测时,效果就不好,经过排查是因为模型中添加了BN层和dropout层,但是在预测时,没有添加eval()导致的,下面是详细记录。

简介

在使用pytorch训练和预测时会分别使用到以下两行代码:

model.train()
model.eval()

下面对两行代码的具体作用进行记录,

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval(),简而言之,就是评估模式,而非训练模式。
在评估模式下,batchNorm层,dropout层等网络层会被关闭,不启用BN层和dropout层,使用训练时模型中保存的参数,从而使得评估时结果不会发生偏移。

在对模型进行评估时,应该配合使用with torch.no_grad()model.eval()

...
model.eval()
with torch.no_grad():
    Evaluation
...

如果模型中有batchNorm以及dropout等层,不添加model.eval()的话,结果是不可预料的。本人在进行评估LPRNet车牌识别时,就出现了这种情况,忘记添加eval(),结果飘忽不定。

不添加eval(),结果如下:
torch中的model.eval()、model.train()详解_第1张图片

添加eval(),结果如下:

torch中的model.eval()、model.train()详解_第2张图片
可以看出添加eval()后,模型的准确度立刻就上去了。

2.1 总结

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

1)训练过程中BN的变化。
在训练过程中BN会不断的计算均值和方差,训练结束后得到最终的均值和方差,在此处将其记为mean_train,variance_train。

2)预测过程中BN的变化。
预测过程中如果不使用model.eval()的话,BN层还是会根据输入的预测数据继续计算均值和方差,假设输入一条预测数据后,BN层计算得到其均值和方差分别为mean_test,variance_test,此时BN层的均值和方差则变成了(mean_train+mean_test),(variance_train+variance_test),相比于训练过程中的均值和方差发生了变化因此会导致预测结果发生变化。

如果使用model.eval()则BN层就不会再计算预测数据的均值和方差,即在预测过程中BN层的均值和方差就是训练过程得到的均值和方差mean_train,variance_train,此时预测结果就不会再发生变化。

3)训练过程中Dropout的变化
训练过程中依据设置的dropout比例会使一部分的网络连接不进行计算。

4)预测过程中Dropout的变化
预测过程中如果不使用model.eval()的话,依然会使一部分的网络连接不进行计算,而使用model.eval()后就是所有的网络连接均进行计算。

你可能感兴趣的:(python,opencv,计算机视觉)