The role of model.train() and model.eval() in PyTorch

When training a model with PyTorch, we first call

model.train()

After training, when we use the model for inference, we likewise first call

model.eval()

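These two calls tell the model whether it is in the training phase or the inference phase. This matters because some parts of a model behave differently during training and inference: Dropout layers drop activations only during training, BN (BatchNorm) layers normalize with the current batch's statistics during training but with accumulated running statistics during inference, and Faster R-CNN selects region proposals differently in the two phases, among others.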
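To make the difference concrete, the following minimal sketch runs an nn.Dropout layer and an nn.BatchNorm1d layer in both modes. The input shapes, p=0.5 and the random seed are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

x = torch.ones(1, 8)

# Dropout: in training mode it zeroes elements with probability p and
# scales the survivors by 1 / (1 - p); in eval mode it returns the input unchanged.
drop = nn.Dropout(p=0.5)
drop.train()
y_train = drop(x)   # every entry is either 0.0 or 2.0
drop.eval()
y_eval = drop(x)    # identical to x

# BatchNorm: in training mode it normalizes with the current batch's
# statistics and updates running_mean / running_var; in eval mode it
# normalizes with the stored running statistics and does not update them.
bn = nn.BatchNorm1d(3)
batch = torch.randn(4, 3) * 5 + 10
bn.train()
out_train = bn(batch)                  # per-feature mean of the output is ~0
running_before = bn.running_mean.clone()
bn.eval()
out_eval = bn(batch)                   # uses running stats, so the output differs
running_after = bn.running_mean.clone()
```

In other words, the same layer object computes different things depending only on its mode.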
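So what do these two calls actually do to tell the model which phase it is in? Both train() and eval() are implemented in torch's Module class (torch.nn.Module). The source is as follows (excerpt):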

class Module(object):
    _version = 1

    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
    # ...... (other attributes and methods omitted)
    def train(self, mode=True):
        r"""Sets the module in training mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                         mode (``False``). Default: ``True``.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        r"""Sets the module in evaluation mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        This is equivalent with :meth:`self.train(False)`.

        Returns:
            Module: self
        """
        return self.train(False)
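The pattern itself is small enough to reproduce without PyTorch. The sketch below uses a hypothetical TinyModule class (not part of torch) that keeps only the training flag, a list of children, and the train()/eval() logic from the excerpt above.

```python
class TinyModule:
    """Hypothetical, stripped-down stand-in for nn.Module (not PyTorch code)."""

    def __init__(self, *children):
        self.training = True            # same default as nn.Module
        self._children = list(children)

    def children(self):
        return iter(self._children)

    def train(self, mode=True):
        # Set the flag on this module, then recurse into every child.
        self.training = mode
        for child in self.children():
            child.train(mode)
        return self                     # returning self allows chaining

    def eval(self):
        # eval() is just train(False).
        return self.train(False)


leaf_a, leaf_b = TinyModule(), TinyModule()
inner = TinyModule(leaf_b)
root = TinyModule(leaf_a, inner)

root.eval()
print([m.training for m in (root, leaf_a, inner, leaf_b)])  # [False, False, False, False]
root.train()
print([m.training for m in (root, leaf_a, inner, leaf_b)])  # [True, True, True, True]
```

Only the root is called directly; the nested leaf_b is reached through inner's own train() call.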

We can see that the Module class defines an attribute named training and initializes it to True:

self.training = True

When we write our own network, we subclass torch.nn.Module:

import torch.nn as nn

class Network(nn.Module):
    """Our own network (layers and forward() omitted)."""

After defining the network, we instantiate it:

model = Network()
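Right after construction the flag is already True, because nn.Module.__init__ sets it (a subclass that defines its own __init__ must call super().__init__() for this to happen). Layers read the flag inside forward(); for example, nn.Dropout's forward passes self.training to F.dropout. The NoisyIdentity layer below is a hypothetical illustration, not a PyTorch class, that follows the same pattern.

```python
import torch
import torch.nn as nn

class NoisyIdentity(nn.Module):
    """Hypothetical layer: adds Gaussian noise only while training."""

    def __init__(self, std=0.1):
        super().__init__()   # sets self.training = True, among other things
        self.std = std

    def forward(self, x):
        # self.training is the flag inherited from nn.Module;
        # .train() / .eval() flip it.
        if self.training:
            return x + torch.randn_like(x) * self.std
        return x


torch.manual_seed(0)  # arbitrary seed
layer = NoisyIdentity()
print(layer.training)             # True: set by nn.Module.__init__

x = torch.zeros(2, 3)
layer.eval()
print(torch.equal(layer(x), x))   # True: no noise in eval mode
```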

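model is now an instance of a subclass of torch.nn.Module, so it inherits Module's methods. Calling model.train() runs Module's train() method, which sets the model's training attribute to True and then calls train(True) on each direct child module. Since each child does the same for its own children, every submodule in the hierarchy ends up with training set to True.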
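The recursion can be checked on a nested model. The sketch below fleshes out Network with hypothetical layers (the sizes are arbitrary) and uses model.modules(), which yields the model itself plus every submodule. It also shows that calling eval() on a submodule affects only that subtree, which is one way to, for instance, keep BatchNorm running statistics fixed while the rest of the model trains.

```python
import torch.nn as nn

class Network(nn.Module):
    # Hypothetical layers, chosen only to create a nested hierarchy.
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())
        self.head = nn.Sequential(nn.Dropout(0.5), nn.Linear(8, 2))

    def forward(self, x):
        return self.head(self.features(x))


model = Network()

model.eval()
all_eval = all(not m.training for m in model.modules())   # True

model.train()
all_train = all(m.training for m in model.modules())      # True

# Switching only one submodule affects just that subtree.
model.features.eval()
features_modes = [m.training for m in model.features.modules()]  # all False
head_modes = [m.training for m in model.head.modules()]          # all True
```

Note that a later model.train() call would switch model.features back to training mode as well.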

    def train(self, mode=True):
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

Calling model.eval() runs Module's eval() method, which simply calls train(False); this sets training to False on the model and, through the same recursion, on all of its submodules.

    def eval(self):
        return self.train(False)
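Because eval() only flips the training flag, two practical consequences follow, shown in the sketch below (the small nn.Sequential model is an arbitrary example). First, eval() and train() return the module itself, so they can be chained. Second, eval() does not turn off autograd; for inference it is common to additionally run the forward pass inside torch.no_grad().

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(0.5))
x = torch.randn(3, 4)

# eval() is train(False); both return the module itself.
same_object = model.eval() is model          # True
flag_after_eval = model.training             # False

# eval() does not disable gradient tracking: the output still requires grad.
y = model(x)
grad_in_eval = y.requires_grad               # True

# Disabling autograd is a separate step, e.g. with torch.no_grad().
with torch.no_grad():
    y_nograd = model(x)
grad_in_no_grad = y_nograd.requires_grad     # False
```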
