Pytorch学习笔记五——net.train与net.eval

net.train() 和 net.eval() 两个函数只要适用于Dropout与BatchNormalization的网络,会影响到训练过程中这两者的参数。

  • net.train()时,训练时每个min-batch时都会根据情况进行上述两个参数的相应调整
  • net.eval()时,由于网络训练完毕后参数都是固定的,因此每个批次的均值和方差都是不变的,因此直接结算所有batch的均值和方差。所有Batch Normalization的训练和测试时的操作不同。
class Model1(nn.Module):
    
    def __init__(self):
        super(Model1, self).__init__()
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        return self.dropout(x)

m1 = Model1()
inputs = torch.ones(10)

print(inputs)
print(20 * '-' + "train model:" + 20 * '-' + '\r\n')
print(m1(inputs))
print(20 * '-' + "eval model:" + 20 * '-' + '\r\n')
m1.eval()
print(m1(inputs))
"""
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
--------------------train model:--------------------

tensor([0., 2., 0., 2., 2., 0., 0., 0., 2., 0.])
--------------------eval model:--------------------

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
"""

你可能感兴趣的:(Pytorch学习笔记五——net.train与net.eval)