pytorch判断模型是否处于训练状态

使用self.training,这样就可以让forward函数采用两种执行方式,然后就可以做一些骚操作了

import torch.nn as nn

class myNet(nn.Module):
    def __init__(self):
        super(myNet, self).__init__()
    def forward(self):
        if self.training:
            print('training')
        else:
            print('not training')
            
model = myNet()
model.train()
result = model()
model.eval()
result = model()

你可能感兴趣的:(Python,机器学习,python,机器学习,神经网络,pytorch)