nn.Module介绍(一)

目录

文章目录

  • 目录
  • 前言
  • 一、构造一个极简网络
  • 二、初始化部分介绍
  • 三、添加一个卷积层的简单网络
  • 四、通过简单网络洞察nn.Module的其他属性和方法
  • 总结


前言


一、构造一个极简网络

import torch
import torch.nn as nn
from collections.abc import Iterable, Iterator

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    def forward(self,x):
        ...

if __name__ == '__main__':
    net = Net()

  首先Net继承自nn.Module,通过super(python中的超类)完成父类的初始化。

二、初始化部分介绍

  在nn.Module()代码中,完成初始化方法通过python的魔法函数__setattr__完成。简单介绍下该魔法函数:setattr (object,name,value)用于设置当前对象(object)的属性(name)值(value)。当然,name属性不一定存在。简单有个概念即可,不理解魔法函数也没关系。
  看下nn.Module中初始化部分的代码。

class Module:
    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        # 初始化的属性
        self.training = True
        self._parameters = OrderedDict()                # 存储模型参数,参与BP
        self._buffers = OrderedDict()                   #  中间变量,不参与BP
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()                   # 添加模块:比如conv/bn等
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    def remove_from(*dicts_or_sets):
        for d in dicts_or_sets:
            if name in d:
                if isinstance(d, dict):
                    del d[name]
                else:
                    d.discard(name)
    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)                  #注册进self._parameters
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)                  #注册进self._parameters
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
            modules[name] = value                               # 注册进modules
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)"
                                .format(torch.typename(value), name))
            modules[name] = value                               # 注册进modules
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)"
                                    .format(torch.typename(value), name))
                buffers[name] = value                           # 注册进buffers
            else:
                object.__setattr__(self, name, value)

  这里简单介绍下流程:以完成“self._parameters()属性”注册为例。当程序遇见self._parameters() 属性时,将自动执行def setattr(self,name,value)函数:此处self表示当前类本身。在初始化中,name = self._parameters(); value =OrderDict(); 之后在setattr中,会依次判断当前属性name是否是Parameter类/Module类/buffer。由于当前OrderDict都不是,因此,执行最底下代码:object.setattr(self,name,value)完成属性 self._parameters()的添加。
  由于本代码中初始化均是OrderDict(),因此,实质上在初始化这个极简网络时,均是执行了 最后一行代码,即object.setattr(self,name,value)

三、添加一个卷积层的简单网络

 在上一节中主要介绍了通过super借助__setattr__完成了怎样的初始化。假如现在添加一个卷积层,初始化部分会有何不同呢?

import torch
import torch.nn as nn
from collections.abc import Iterable, Iterator

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.lele= nn.Conv2d(1, 1, 1, 1, 0)

    def forward(self,x):
        ...

if __name__ == '__main__':
    net = Net()

 在第1小节中,实质上执行完了super(Net,self).init()语句。现在程序添加了一个卷积层self.lele=nn.Conv2d(1,1,1,1,0):当然,本节我们只关注Module的学习,并不关注卷积层是如何初始化的,即Conv2d类如何创建本文不讨论。本文默认已经实例了一个self.lele对象。现在核心是:如何将这个self.lele对象添加进Module初始化中呢?由于往类中添加了新的属性,因此,首先执行__setattr__函数。object = self , name=lele, value= nn.Conv2d类。
观察此时函数内部的运行情况。
 这里简单介绍下nn.Conv2d类。在pytorch中,Conv2d类继承自_ConvNd类。而_ConvNd类继承自nn.Module(); _ConvNd主要完成卷积核的一些初始化参数工作。而具体到Conv2d/Conv1d类时,主要负责前向传播的运算任务。
 Okay,言归正传。在执行到self.lele=nn.Conv2d(1,1,1,1,0)这句代码时:首先肯定完成卷积初始化工作,即完成_ConvNd中参数工作。以Conv2d中in_channel参数为例,首先也是通过__setattr__方法判断in_channel是否是Parameter/Module/Buffer,肯定均不属于。因此,同样也是执行object.setattr(self,name,value)这句代码。即将in_channel这种参数往卷积核类中添加了in_channel属性。
 当然,_ConvNd中还有好多参数:同样采用上段中的方法完成了参数注册:

in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t,
padding: _size_1_t,
dilation: _size_1_t,
transposed: bool,
output_padding: _size_1_t,
groups: int,
bias: Optional[Tensor],
padding_mode: str) -> None:

 Conv2d中有weights(卷积核权重)和bias(卷积核偏置)两个Parameter类,即Conv2d中待学习的参数有两个weights和bias。在torch源码中长这样:
那么,如何将这两个Parameter类注册进self._parameters()?
同理:在__setattr__函数中:

def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
...
    if isinstance(value, Parameter):               #判断若value是Parameter类
        self.register_parameter(name, value)       # 则注册进self._parameter 

 简单说就是判断下value的类型,若为Parameter,则注册进_parameters()。这里在看下register_parameter()函数代码:

def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
    r"""Adds a parameter to the module.

    The parameter can be accessed as an attribute using given name.

    Args:
        name (string): name of the parameter. The parameter can be accessed
            from this module using the given name
        param (Parameter): parameter to be added to the module.
    """
    if '_parameters' not in self.__dict__:
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")

    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    elif name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))

    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or None required)"
                        .format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError(
            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another Tensor, compute the value in "
            "the forward() method.".format(name))
    else:
        self._parameters[name] = param

 从上述代码可以看出:实质上就是判断Parameter是否符合各种规范,比如命名啥的。若全都符合,则执行最后一句代码,完成注册。
 当然:理解完上述完全ok。还有一个细节涉及到python的又一个高级操作:就是标黄的部分:hasattr,python用来判断对象是否具有某个属性。在pytorch中重写了这个魔法方法:即hasattr实质上跳转到了__getattr__中。看下__getattr__(object_name,name)—用来获取对象的属性:

def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
    if '_parameters' in self.__dict__:
        _parameters = self.__dict__['_parameters']
        if name in _parameters:
            return _parameters[name]
    if '_buffers' in self.__dict__:
        _buffers = self.__dict__['_buffers']
        if name in _buffers:
            return _buffers[name]
    if '_modules' in self.__dict__:
        modules = self.__dict__['_modules']
        if name in modules:
            return modules[name]
    raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))

 逻辑也比较简单:若具有这个属性,则通过字典进行返回:_parameters[name]。通过上述方式,就完成了Conv2d中weights和bias两个Parameter类的注册,同时也完成了Conv2d中所有参数的注册。
 上述实质上仅仅完成了Conv2d中参数的初始化,现在,实质上实例化出了一个名为lele的卷积核对象,属于一个Module类。因此,还需要通过__setattr__将lele这个Module类注册进self._modules()。具体通过:

def add_module(self, name: str, module: Optional['Module']) -> None:
    r"""Adds a child module to the current module.

    The module can be accessed as an attribute using the given name.

    Args:
        name (string): name of the child module. The child module can be
            accessed from this module using the given name
        module (Module): child module to be added to the module.
    """
    if not isinstance(module, Module) and module is not None:
        raise TypeError("{} is not a Module subclass".format(
            torch.typename(module)))
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("module name should be a string. Got {}".format(
            torch.typename(name)))
    elif hasattr(self, name) and name not in self._modules:
        raise KeyError("attribute '{}' already exists".format(name))
    elif '.' in name:
        raise KeyError("module name can't contain \".\"")
    elif name == '':
        raise KeyError("module name can't be empty string \"\"")
    self._modules[name] = module

 类似于注册register_parameter,此出不在赘述。
 对于lele来说:因为lele从nn.Conv2d实例化的对象而来。所以,lele作为一个module(容器)包裹了待优化的参数(weights,bias)。即可以通过module._parameters查看待优化的参数。

四、通过简单网络洞察nn.Module的其他属性和方法

 上述其实已经完成了一个网络的构建,只要实现forward方法,就能运行了。但是在现有大多数pytorch参考书中:会告诉你诸如“调用.children()方法”。我就奇了怪了,咋调用的?因此,本节主要分析nn.Module中如何实现查看一个网络的module,parameter,buffer等。
 截取nn.Module中查看module等属性的代码:还是简单网络为例。

def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    for name, param in self.named_parameters(recurse=recurse):
        yield param

def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

 上述代码用来查看parameters。通过yield返回了一个生成器。生成器不理解请看另一篇博客。就是打印每个模块的参数。
 同时,nn.Module还提供了查看其它参数的方法,代码类似。

def buffers(self, recurse: bool = True) -> Iterator[Tensor]:

    for name, buf in self.named_buffers(recurse=recurse):
        yield buf
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
      gen = self._named_members(
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem
def children(self) -> Iterator['Module']:
    for name, module in self.named_children():
        yield module

def named_children(self) -> Iterator[Tuple[str, 'Module']]:
       memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module
def modules(self) -> Iterator['Module']:
    for name, module in self.named_modules():
        yield module

def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
     if memo is None:
        memo = set()
    if self not in memo:
        memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            for m in module.named_modules(memo, submodule_prefix):
                yield m

 上述全部都是生成器。下面我给出一个如何使用这些方法的示例代码(写了我好久,太麻烦了。。。):

# -*- coding: utf-8 -*-
# ======================================================
# @Time    : 2021/01/04
# @Author  : lele wu
# @Email   : [email protected]
# @File    : lele_module.py
# @Comment: 研究nn.Module属性和方法
# ======================================================

import torch
import torch.nn as nn
from collections.abc import Iterable, Iterator

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.lele = nn.Conv2d(1, 1, 1, 1, 0)
        #self.bn = nn.BatchNorm2d(1)
        #self.relu = nn.ReLU(inplace=True)

    def forward(self,x):
        #out = self.conv(x)
        #out = self.bn(out)
        #out = self.relu(out)
        #return out
        ...

if __name__ == '__main__':

    # input = torch.Tensor([[[[-1,0],[1,2]]]])
    net = Net()
    # 查看网络的子模块
    print('net.children返回一个迭代器吗?:', isinstance(net.children(), Iterator)) # 因为使用了yeild生成器表达式,故是生成器
    print('***********************************************************************************************************')
    print('通过调用children()方法:')
    for module in net.children():
        print('当前module类为:',module)
    print('***********************************************************************************************************')
    print('通过调用named_children()方法:')
    for name,module in net.named_children():
        print('类名:',name,'运算模块:',module)  # name: 就是自己定义网络中 conv,bn,bn
    print('***********************************************************************************************************')
    print('通过调用modules()方法:')
    for module in net.modules():
        print('当前module类为:',module)
    print('***********************************************************************************************************')
    print('通过调用named_modules()方法:')
    for module in net.named_modules():
        print('当前module类为:',module)
    print('***********************************************************************************************************')
    print('总结: named_modules或者modules方法递归调用所有子类;而children和named_children方法则仅调用第一层子类:\n'
          '参考资料: https://zhuanlan.zhihu.com/p/65105409?utm_source=wechat_session')
    print('***********************************************************************************************************')
    print('通过调用buffers()方法:')
    for buffer in net.buffers():
        print('当前buffer为:',buffer)
    print('***********************************************************************************************************')
    print('通过调用named_buffers()方法:')
    for buffer in net.named_buffers():
        print('当前buffer(元祖)为:',buffer)
    print('***********************************************************************************************************')

 此处我放张运行效果图得了:
nn.Module介绍(一)_第1张图片
 当然,由于仅仅有一个卷积核。因此buffers为空,即卷积核在训练过程中不产生缓存。

总结

  这是第一篇,写的有点儿多。后续会介绍hook,apply等。当然,下节会首先介绍如何冻结参数。

你可能感兴趣的:(#,nn.Module,人工智能,编程语言)