torch.nn.Module源码学习

nn.Module是使用pytorch进行神经网络训练的主要载体,是所有网络的基类。首先看一下它的构造函数:

    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._modules = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
  • self.training标志网络的状态,主要影响bn和dropout等在网络训练和评估时使用方法不一样的功能;
  • self._parameters保存当前module的训练参数;
  • self._buffers保存当前moduile的非训练参数,如bn的running_mean和running_var;
  • self._modules保存当前module中的子module(不是自定义模型中的所有module);
  • 其余的属性均是用户定义的hook函数。关于hook函数可参数:pytorch的自动求导和hook技术简介

下面分类来看nn.Module中的其他非辅助方法:

  1. 添加新元素:
  • register_buffer:向self._buffers注册新元素
  • register_parameter:向self._parameters注册新元素
  • add_module:向self._modules注册新元素
  1. 类型转换:
  • cuda:将所有的parameters和buffers移动到gpu

  • cpu:将所有的parameters和buffers移动到cpu

  • type:将所有的parameters和buffers都转换为指定的目标类型

  • float:将所有的parameters和buffers都转换为float类型

  • double:将所有的parameters和buffers都转换为double类型

  • half:将所有的parameters和buffers都转换为float16类型

  • to:该函数有四种用法:

    • to(device=None, dtype=None, non_blocking=False):转移到指定的device;
    • to(dtype, non_blocking=False):转换为指定的dtype;
    • to(tensor, non_blocking=False):将tensor属性(dtype和device)转换到与指定tensor相同;
    • to(memory_format=torch.channels_last):改变4d tensor的存储格式,NCHW或NHWC。
  • 上述所有函数的功能均借助_apply完成:

    def _apply(self, fn):
            for module in self.children():
                module._apply(fn)
    
            def compute_should_use_set_data(tensor, tensor_applied):
    		# 是否进行in-place操作
    
            for key, param in self._parameters.items():
                if param is not None:
                    # Tensors stored in modules are graph leaves, and we don't want to
                    # track autograd history of `param_applied`, so we have to use
                    # `with torch.no_grad():`
                    with torch.no_grad():
                        param_applied = fn(param)
                    should_use_set_data = compute_should_use_set_data(param, param_applied)
                    if should_use_set_data:
                        # 直接替换旧数据
                        param.data = param_applied
                    else:
                        assert isinstance(param, Parameter)
                        assert param.is_leaf
                        # 注册新的Parameter
                        self._parameters[key] = Parameter(param_applied, param.requires_grad)
    				
                    # 对param.grad进行相同的操作
                    if param.grad is not None:
                        with torch.no_grad():
                            grad_applied = fn(param.grad)
                        should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                        if should_use_set_data:
                            param.grad.data = grad_applied
                        else:
                            assert param.grad.is_leaf
                            self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)
    
            # 更新_buffers
            for key, buf in self._buffers.items():
                if buf is not None:
                    self._buffers[key] = fn(buf)
    
            return self
    

    _apply函数遍历了所有的parameters(如果有grad还要遍历grad)和buffers,对它们应用fn,并将fn作用的结果注册到相应的容器中。

  1. hook注册:
  • register_backward_hook:向self._backward_hooks注册新元素

  • register_forward_pre_hook:向self._forward_pre_hooks注册新元素

  • register_forward_hook:向self._forward_hooks注册新元素

    关于hook技术的详细介绍可参考我的里一篇文章pytorch的自动求导和hook技术简介

  1. 模型保存和加载:
  • state_dict:返回一个包含模型所有parameters和buffers的字典。需要注意的是,所有的参数在被添加到字典中之前,都要经过self._state_dict_hooks中的hook函数处理。

    该函数是个递归函数,每次递归调用了辅助函数_save_to_state_dict来将当前module的子module参数添加到字典中。

  • load_state_dict:将指定的state_dict中的参数加载到当前的module中。需要注意的是,所有的参数在被添加到当前模型的_parameters和 _ buffers之前,都要经过self. _load_state_dict_pre_hooks中hook函数处理。

    该函数中定义了一个递归函数load。load每次递归时调用辅助函数_load_from_state_dict来添加当前module的子module参数。

    该函数有一个输入参数strict,默认为True。strict为True时,会返回一个namedtuple。该tuple有两个属性:missing_keys和unexpected_keys。missing_keys是存储当前module有而待加载的state_dict中没有的参数的列表;unexpected_keys是存储当前module没有而待加载的state_dict中有的参数的列表。

  1. 信息查询:
  • named_modules:返回模型中的所有module(包括模型本身)及其名称。返回值中有两个值,第一个为module的名(string格式),第二个为相应的module。返回顺序是自顶向下。

  • modules:返回模型中的所有module,调用named_modules完成。

  • named_children:返回当前模型_modules中的所有元素及其名称。注意该函数与named_modules的区别:named_modules是递归函数,每次递归均查询当前module的 _modules,进而能够遍历模型中的所有module;而named_chilldren只查询整个模型的 _modules

  • children:返回当前模型_modules中的所有元素,调用named_modules来完成。

  • named_parameters:返回模型中所有的可训练参数及其名称。

  • parameters:返回模型中所有的可训练参数。

  • named_buffers:返回模型中所有的非训练参数及其名称。

  • buffers:返回模型中所有的非训练参数。

    上述所有的方法返回的都是由yield定义的生成器。
    前两个方法借助named_modules来完成。

    def named_modules(self, memo=None, prefix=''):
          r"""Returns an iterator over all modules in the network, yielding
          both the name of the module as well as the module itself.
          Yields:
              (string, Module): Tuple of name and module
          Note:
              Duplicate modules are returned only once. In the following
              example, ``l`` will be returned only once.
          """
          if memo is None:
              memo = set()
          if self not in memo:
              memo.add(self)  # 包括当前模型自身
              yield prefix, self
            for name, module in self._modules.items():
                  if module is None:
                    continue
                  submodule_prefix = prefix + ('.' if prefix else '') + name
                  # 递归返回每一级所有module
                  for m in module.named_modules(memo, submodule_prefix):
                      yield m
    

    named_children和children直接返回_modules中的所有元素。
    最后四个方法调用了辅助函数_named_members来实现:

        def _named_members(self, get_members_fn, prefix='', recurse=True):
            r"""Helper method for yielding various names + members of modules."""
            memo = set()
            # 调用named_modules获得模型中的所有module
            modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
            for module_prefix, module in modules:
                members = get_members_fn(module) # 获得当前module的所有参数
                for k, v in members:
                    if v is None or v in memo:
                        continue
                    memo.add(v)
                    name = module_prefix + ('.' if module_prefix else '') + k
                    yield name, v
    

用一个简单的示例来学习他们的区别:

import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Linear(32, 3)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 32)
        x = self.fc(x)
        return x

model = net()

for k, m in model._modules.items():
    print(k, m)
# conv Sequential(
#   (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#   (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(inplace=True)
# )
# fc Linear(in_features=32, out_features=3, bias=True)

for k, v in model.named_modules():
	print(k, v)
#  net(
#   (conv): Sequential(
#     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (2): ReLU(inplace=True)
#   )
#   (fc): Linear(in_features=32, out_features=3, bias=True)
# )
# conv Sequential(
#   (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#   (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(inplace=True)
# )
# conv.0 Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
# conv.1 BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# conv.2 ReLU(inplace=True)
# fc Linear(in_features=32, out_features=3, bias=True)

for k, v in model.named_children():
	print(k, v)
# conv Sequential(
#   (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#   (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(inplace=True)
# )
# fc Linear(in_features=32, out_features=3, bias=True)
# 注意children()和modules()的区别:children只返回一级子module

for k, _ in model.named_parameters():
	print(k)
# conv.0.weight
# conv.1.weight
# conv.1.bias
# fc.weight
# fc.bias

for k, _ in model.named_buffers():
	print(k)
# conv.1.running_mean
# conv.1.running_var
# conv.1.num_batches_tracked
  1. 状态设置:
  • train:将模型设为训练模式;主要影响BN、Dropout等module。
  • eval:将模型设为评估模式;主要影响BN、Dropout等module。调用train来实现。
  • require_grad_:设置模型中所有的Parameter的require_grad属性,即是否要计算梯度。调用parameters()来实现。
  • zero_grad:将模型中所有的Parameter的梯度置为零,并将其从计算图中分离。调用parameters()来实现。
  1. 其他:
  • apply:将自定义函数作用于模型的所有子module。

    def apply(self, fn):
        for module in self.children():
            module.apply(fn)
            fn(self)
            return self
    

你可能感兴趣的:(PyTorch)