论model.eval()的重要性

https://blog.csdn.net/qq_38410428/article/details/101102075

model.train()和model.eval()很重要是因为:Batch Normalization和Dropout两层。

如果模型中有BN层(Batch Normalization)和 Dropout(),需要在训练时添加model.train(),测试时添加model.eval()。以此保证在测试时保留特定的神经连接路径(dropout),以及不再更新全局均值和方差(BN)(全局值不是因为反向传播更新的而是每次有数据由momentum控制更新的)。

model.train()和model.eval()的添加位置:

def train(model, optimizer, epoch, train_loader, validation_loader):
    model.train() 
    """
    错误的位置
    """
    for batch_idx, (data, target) in experiment.batch_loop(iterable=train_loader):
        model.train()  
       """
        正确的位置,保证每一个batch都能进入model.train()的模式
       """
        data, target = Variable(data), Variable(target)
        # Inference
        output = model(data)
        ...
def test(model, test_loader):
    model.eval()
    ...

你可能感兴趣的:(论model.eval()的重要性)