PyTorch Framework Learning - 2: Understanding torch.nn.modules.Module (nn.Module)

Table of Contents

  • PyTorch Framework Learning - 2: Understanding torch.nn.modules.Module (nn.Module)
    • The simplest example
      • Analysis
    • Selected source code:
      • Basic attributes
        • dump_patches
        • _version
        • training
      • The initializer
        • _parameters
        • _buffers
        • _modules
        • Other attributes
      • forward
      • Registration methods
        • register_buffer
        • register_parameter
        • add_module
      • zero_grad()
      • train
    • A more complex example

PyTorch Framework Learning - 2: Understanding torch.nn.modules.Module (nn.Module)

The simplest example

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # assigning nn.Module attributes registers them as submodules
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # the per-model computation: two convolutions with ReLU activations
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Analysis

  1. A PyTorch model should be written as a class
  2. A PyTorch model should be a subclass of nn.Module
  3. A model involves two stages: initialization and forward propagation

Initialization exists to register the parameters, so that the module can manage these important tensors properly; it is clearly necessary.
The forward pass of each network has to be defined by its author, otherwise the model would lose what makes it distinct. A usage sketch follows.
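
A quick usage sketch (the 32x32 input below is an assumption; any input large enough to survive two 5x5 convolutions works). Note that the model is invoked as model(x), not model.forward(x): __call__ dispatches to forward and also runs any registered hooks.

import torch

model = Model()
x = torch.randn(1, 1, 32, 32)  # hypothetical input: one 1-channel 32x32 image
out = model(x)                 # each 5x5 conv shrinks the spatial size by 4
print(out.shape)               # torch.Size([1, 20, 24, 24])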

Selected source code:

Basic attributes

class Module:
    dump_patches: bool = False
    _version: int = 1
    training: bool

dump_patches

Contrary to what the name might suggest, this has nothing to do with .to()/.cuda() moving parameters to the GPU. According to the docstring in the source, it is part of the backward-compatibility support for load_state_dict, working together with _version to handle state dicts saved by older versions of a module.

_version

Used later for version comparison when loading a state dict: if a module's parameter/buffer layout changes, _version is bumped, and _load_from_state_dict can compare the version stored in the checkpoint's _metadata against the current one (BatchNorm, for example, uses this to handle old checkpoints that lack num_batches_tracked).

training

Toggled by the train(mode) method; set to True by default in __init__.

It mainly affects layers such as BatchNorm and Dropout, whose behavior differs between training and evaluation. A sketch of the effect follows.
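
A small sketch of the effect, using Dropout as a stand-in for any training-dependent layer:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()    # training=True: roughly half the entries are zeroed (the rest scaled by 2)
print(drop(x))

drop.eval()     # training=False: Dropout becomes the identity
print(drop(x))  # all ones, unchanged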

The initializer

    def __init__(self):
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

_parameters

Holds the trainable parameters of the current module

_buffers

Holds the non-trainable tensors (buffers) of the current module

_modules

Holds the submodules contained in the current module
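
A sketch of how these three dicts are populated for the Model defined above (attribute assignment routes nn.Module values into _modules via Module.__setattr__):

model = Model()
print(list(model._modules.keys()))           # ['conv1', 'conv2']
print(list(model._parameters.keys()))        # [] -- the weights live on the submodules
print(list(model.conv1._parameters.keys()))  # ['weight', 'bias']
print(list(model.conv1._buffers.keys()))     # [] -- Conv2d registers no buffers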

Other attributes

These hold user-defined hook functions.

See: automatic differentiation and hook techniques

forward

def _forward_unimplemented(self, *input: Any) -> None:
    raise NotImplementedError

forward: Callable[..., Any] = _forward_unimplemented

forward is preset to the _forward_unimplemented method, which raises an exception for any input; a subclass must implement forward to override this attribute. It feels like a C++ virtual function, but the implementation is a neat trick: forward is simply a class attribute holding a callable.
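
A quick demonstration of the default behavior:

import torch.nn as nn

class Empty(nn.Module):
    pass  # forward deliberately not implemented

try:
    Empty()(42)
except NotImplementedError:
    print("forward not implemented")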

Registration methods

    def register_buffer(self, name: str, tensor: Tensor, persistent: bool = True) -> None:
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("buffer name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)

    def register_parameter(self, name: str, param: Parameter) -> None:
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")

        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("parameter name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param

    def add_module(self, name: str, module: 'Module') -> None:
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("module name should be a string. Got {}".format(
                torch.typename(name)))
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError("attribute '{}' already exists".format(name))
        elif '.' in name:
            raise KeyError("module name can't contain \".\"")
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        self._modules[name] = module

register_buffer

self._buffers[name] = tensor
if persistent:
	self._non_persistent_buffers_set.discard(name)
else:
	self._non_persistent_buffers_set.add(name)

In fact, after the validation of the arguments (illegal types, illegal names, and so on), only the five lines above do the actual work:

Store the incoming tensor in _buffers.

Check whether the buffer should be persistent.

Record the names of non-persistent buffers in self._non_persistent_buffers_set.

Since persistent may differ between calls for the same name, discard() is used to remove the entry: unlike remove(), it does not raise when the element is absent, and it also saves a separate membership check.
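
A sketch of typical usage (RunningMean is a made-up module in the spirit of BatchNorm's running statistics):

import torch
import torch.nn as nn

class RunningMean(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # saved in state_dict and moved by .to()/.cuda(), but never trained
        self.register_buffer('mean', torch.zeros(dim))
        # usable in forward, but excluded from state_dict
        self.register_buffer('scratch', torch.zeros(dim), persistent=False)

m = RunningMean(4)
print(list(m.state_dict().keys()))  # ['mean'] -- 'scratch' is non-persistent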

register_parameter

if param is None:
	self._parameters[name] = None
elif not isinstance(param, Parameter):
	raise TypeError("cannot assign '{}' object to parameter '{}' (torch.nn.Parameter or None required)"
                    .format(torch.typename(param), name))
elif param.grad_fn:
    raise ValueError(
        "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
        "parameters must be created explicitly. To express '{0}' "
        "as a function of another Tensor, compute the value in "
        "the forward() method.".format(name))
else:
	self._parameters[name] = param

After checks similar to those for buffers:

  1. If the parameter is None, it can simply be stored in _parameters
  2. If it is not an instance of Parameter, it obviously cannot be used as a parameter
  3. If it has a grad_fn (the attribute autograd uses for backpropagation), it was produced by a computation, i.e. it is a non-leaf tensor, and should not be stored as a parameter
  4. Otherwise, store the parameter in _parameters (see the sketch below)
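
A sketch of the success and failure cases (Scaler and the parameter names are made up for illustration):

import torch
import torch.nn as nn

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        # equivalent to: self.scale = nn.Parameter(torch.ones(1)),
        # which __setattr__ also routes through register_parameter
        self.register_parameter('scale', nn.Parameter(torch.ones(1)))

s = Scaler()
print(list(s._parameters.keys()))  # ['scale']

try:
    s.register_parameter('bad', torch.ones(1))  # a plain Tensor, not nn.Parameter
except TypeError as e:
    print(e)  # cannot assign ... to parameter 'bad' (torch.nn.Parameter or None required)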

add_module

After the validity checks, the module is stored in _modules.
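
Plain attribute assignment in __init__ (as in the first example) ends up in the same place: Module.__setattr__ detects nn.Module values and registers them. A sketch showing that both spellings are equivalent:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)                # registered via __setattr__
        self.add_module('conv2', nn.Conv2d(20, 20, 5))  # registered explicitly

net = Net()
print(list(net._modules.keys()))  # ['conv1', 'conv2']
print(net.conv2)                  # add_module also makes it reachable as an attribute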

zero_grad()

    def zero_grad(self) -> None:
        if getattr(self, '_is_replica', False):
            warnings.warn(
                "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
                "The parameters are copied (in a differentiable manner) from the original module. "
                "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
                "If you need gradients in your forward method, consider using autograd.grad instead.")

        for p in self.parameters():
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

Clears the gradients computed on the model's parameters. PyTorch accumulates gradients: each backward pass adds to .grad rather than overwriting it, so the gradient after processing several batches would otherwise be the sum of all gradients computed so far. Gradients should therefore be cleared before each new backward pass.
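
A sketch demonstrating the accumulation that zero_grad() prevents:

import torch
import torch.nn as nn

lin = nn.Linear(2, 1)
x = torch.ones(1, 2)

lin(x).sum().backward()
g1 = lin.weight.grad.clone()

lin(x).sum().backward()                         # without zero_grad, gradients add up
print(torch.allclose(lin.weight.grad, 2 * g1))  # True

lin.zero_grad()
print(lin.weight.grad)  # zeroed in place (this version zeroes rather than setting to None)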

train

    def train(self: T, mode: bool = True) -> T:
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

Updates the training attribute of this module and, recursively via self.children(), of every submodule it contains; see the notes on training above.
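
A sketch of the recursive propagation (eval() is just train(False)):

model = Model()  # the Model from the first example
model.eval()
print(model.training, model.conv1.training)  # False False

model.train()
print(model.training, model.conv1.training)  # True True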

A more complex example
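
Below is a minimal sketch of what such an example might look like; the architecture and names are made up, combining the pieces discussed above (submodules, add_module, a direct Parameter, a buffer, and train/eval-dependent behavior):

import torch
import torch.nn as nn

class ComplexModel(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        # submodules, registered into _modules via __setattr__
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.BatchNorm2d(16),  # holds parameters (weight, bias) and buffers (running stats)
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # explicit registration, equivalent to attribute assignment
        self.add_module('classifier', nn.Linear(16 * 14 * 14, num_classes))
        # a direct parameter of this module, stored in _parameters
        self.temperature = nn.Parameter(torch.ones(1))
        # a non-trainable counter, stored in _buffers
        self.register_buffer('calls', torch.zeros(1, dtype=torch.long))

    def forward(self, x):
        if self.training:
            self.calls += 1  # bookkeeping that only happens in train mode
        x = self.features(x)
        x = x.flatten(1)
        return self.classifier(x) / self.temperature

model = ComplexModel()
out = model(torch.randn(2, 1, 28, 28))  # assumed input: 1-channel 28x28 images
print(out.shape)                        # torch.Size([2, 10])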
