【pytorch】model.train()和model.evel()的用法

1.model.train()与model.eval()的用法

看别人的面经时,浏览到一题,问的就是这个。自己刚接触pytorch时套用别人的框架,会在训练开始之前写上model.trian(),在测试时写上model.eval()。然后自己写的时候也就保留了这个习惯,没有去想其中原因。

在经过一番查阅之后,总结如下:
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。

联系Batch Normalization和Dropout的原理之后就不难理解为何要这么做了。

 

不过我还是有疑惑,到底放在下面训练的里面还是在外面?

for img,label in train_loder:

     img,label = img.to(device),label.to(device)

    ...

 

参考:https://www.cnblogs.com/luckyplj/p/13424561.html

你可能感兴趣的:(PyTorch)