A Deep Dive into PyTorch nn.Module's train() and eval() Functions

1. First, a look at the official explanation

[Screenshots of the official documentation for Module.train() and Module.eval()]
Calling eval() on a model puts it in evaluation mode, and calling train(True) puts it in training mode. The official docs don't go much deeper than that, so let's dig into what actually changes when a model is switched to evaluation mode.

#A look at the source of train():
    def train(self: T, mode: bool = True) -> T:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for module in self.children():  # key step: recurse so every child module's training flag changes too, e.g. nn.Dropout()
            module.train(mode)
        return self
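
For comparison, eval() in the same nn.Module source (docstring omitted) is nothing more than a delegation to train(False):

#The source of eval():
    def eval(self: T) -> T:
        return self.train(False)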

2. Deep dive

Let's start with the conclusion: any model class that inherits from nn.Module gets an instance attribute named training. Calling train() (one parameter, mode, defaulting to True) sets training to the value of mode. Calling eval() (no parameters) sets training to False; it is exactly equivalent to train(False).
The training flag in turn controls the behavior of Dropout and BatchNorm layers: in the usual train(True) mode both layers are active, while after eval() they stop "working" in the training sense (Dropout passes its input through unchanged, and BatchNorm normalizes with its stored running statistics instead of the current batch's). A minimal sketch of the whole mechanism follows (the toy model below is made up for illustration):
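import torch
import torch.nn as nn

# A toy model containing the two layer types whose behavior depends on self.training.
model = nn.Sequential(
    nn.Linear(4, 4),
    nn.BatchNorm1d(4),
    nn.Dropout(p=0.5),
)

model.train()                 # sets training=True on the model and, recursively, all children
print(model.training)         # True
print(model[2].training)      # True -- the Dropout child was updated by the recursion

model.eval()                  # equivalent to model.train(False)
print(model[1].training)      # False -- the BatchNorm1d child followed along

# In eval mode the forward pass is deterministic: Dropout is a no-op and
# BatchNorm1d uses its running statistics rather than batch statistics.
x = torch.randn(2, 4)
with torch.no_grad():
    print(torch.allclose(model(x), model(x)))   # True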

#Excerpt from the nn.Module source: self.training defaults to True
class Module:
    dump_patches: bool = False
    _version: int = 1
    training: bool
    _is_full_backward_hook: Optional[bool]
    def __init__(self):
        torch._C._log_api_usage_once("python.nn_module")
        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._is_full_backward_hook = None
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
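
That default is easy to verify on any freshly constructed module (a tiny illustrative check):

import torch.nn as nn

m = nn.Linear(2, 2)
print(m.training)   # True -- set by nn.Module.__init__ above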
#The source of Dropout2d's forward(): self.training is passed straight into the functional call
class Dropout2d(_DropoutNd):
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout2d(input, self.p, self.training, self.inplace)
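
A quick check of that effect (a small demo snippet, not from the PyTorch source):

import torch
import torch.nn as nn

drop = nn.Dropout2d(p=0.5)
x = torch.ones(1, 8, 2, 2)

drop.train()
print(drop(x)[0, :, 0, 0])      # entire channels zeroed at random; survivors scaled by 1/(1-p)=2

drop.eval()
print(torch.equal(drop(x), x))  # True: with training=False, F.dropout2d returns the input unchanged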

#Excerpt from BatchNorm2d's forward pass (the layer ultimately inherits from nn.Module). When self.training is True, batch statistics are used and the running statistics are updated. When self.training is False, the layer does not simply stop working: whether batch statistics are still used depends on whether a running mean and variance were computed beforehand.

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)
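
The practical upshot, sketched below (a minimal demo, not PyTorch source): in training mode BatchNorm2d normalizes with batch statistics and updates running_mean/running_var; in eval mode it reads the stored running statistics; and if the layer was built with track_running_stats=False, running_mean and running_var are None, so by the branch above bn_training stays True and batch statistics are used even in eval mode.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
print(bn.running_mean)      # tensor([0., 0., 0.]) at construction

bn.train()
x = torch.randn(4, 3, 2, 2) + 5.0
bn(x)                       # training pass: normalizes with batch stats,
print(bn.running_mean)      # ...and nudges running_mean toward the batch mean (momentum=0.1)

bn.eval()
bn(x)                       # eval pass: uses the stored running_mean/var, updates nothing

# With track_running_stats=False there are no running stats to fall back on,
# so bn_training stays True even in eval mode and batch statistics are used.
bn_no_stats = nn.BatchNorm2d(3, track_running_stats=False).eval()
print(bn_no_stats.running_mean)   # None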
