The Module Class in PyTorch

The Module Class

Table of Contents

  • The Module Class
    • Introduction
    • The Module Constructor
    • Module Hierarchy
    • Data Members
      • Main Data Members
      • Other Data Members
      • Member Access
        • Accessing Buffers
        • Accessing Modules
        • Accessing Parameters
        • Direct Access
      • Adding and Removing Data Members
        • Direct Addition and Removal
        • Adding New Data Members
        • Adding a Module
    • Data Conversion
    • Switching Network State
    • Forward Pass
    • Hooks
    • Pickle Support
    • Network Parameter Initialization
    • Serialization
    • `__dir__` and `__repr__`
    • zero grad

Introduction

Module is the base class for all neural network modules, and new modules should subclass it. A Module may contain other Modules, allowing them to be nested in a tree structure, and submodules can be assigned as regular class attributes. We use AlexNet as a simple example, alongside the definition of Module, to illustrate this.

import torch.nn as nn
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x

The Module Constructor

Every Module has its own `_parameters`, `_buffers`, `_modules`, and related data members.

class Module(object):
  def __init__(self):
    self._backend = thnn_backend
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._modules = OrderedDict()
    self.training = True
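
A minimal sketch of what these dictionaries hold right after construction (the module `net` below is hypothetical):

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU())  # a throwaway two-layer module

print(type(net._modules))    # <class 'collections.OrderedDict'>
print(list(net._modules))    # ['0', '1'] -- submodules registered by name
print(len(net._parameters))  # 0 -- the Linear's weight and bias live in the child
print(net.training)          # True -- modules start out in training mode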

Module Hierarchy

For a Module, `._modules` is an ordered dictionary whose values are themselves Modules, each with its own `._modules` dictionary. A Module thus defines the tree structure of the entire network by recursively defining its children modules, every one of which is again an instance of the Module class. The tree of submodules has the following shape.

module                       # root; children reached via module._modules.values()
├── module.0                 # ._modules.values()[0]
│   ├── module.0.0
│   ├── ...
│   └── module.0.n
├── ...
└── module.n                 # ._modules.values()[n]

For AlexNet, the tree looks like this:

graph LR
AlexNet["AlexNet"]
features["'features':features"]
classifier["'classifier':classifier"]
AlexNet -- "_modules.values()[0]" --> features
AlexNet -- "_modules.values()[1]" --> classifier
subgraph AlexNet
	AlexNet
	subgraph Sequential
	classifier --> classifier.0["'0':Dropout"]
	classifier --> classifier.1["'1':Linear"]
	classifier --> classifier.2["'2':ReLU"]
	classifier --> classifier.3["'3':Dropout"]
	classifier --> classifier.4["'4':Linear"]
	classifier --> classifier.5["'5':ReLU"]
	classifier --> classifier.6["'6':Linear"]
	end
	subgraph Sequential
	features --> features.0["'0':Conv2d"]
	features --> features.1["'1':ReLU"]
	features --> features.2["'2':MaxPool2d"]
	features --> features.3["'3':Conv2d"]
	features --> features.4["'4':ReLU"]
	features --> features.5["'5':MaxPool2d"]
	features --> features.6["'6':Conv2d"]
	features --> features.7["'7':ReLU"]
	features --> features.8["'8':Conv2d"]
	features --> features.9["'9':ReLU"]
	features --> features.10["'10':Conv2d"]
	features --> features.11["'11':ReLU"]
	features --> features.12["'12':MaxPool2d"]
	end
end

Data Members

Main Data Members

The main data members are `_parameters`, `_buffers`, and `_modules`.

All three are of type `OrderedDict`.

Other Data Members

As the constructor above shows, a Module also carries the hook dictionaries `_backward_hooks`, `_forward_hooks`, and `_forward_pre_hooks`, plus the boolean `training` flag.

Member Access

Each module has its own members, and the members of the complete module are the union of those of all its submodules. Member access is therefore implemented by recursively visiting each submodule and yielding its members.

Accessing Buffers

    def _all_buffers(self, memo=None):
        if memo is None:
            memo = set()
        for name, b in self._buffers.items():
            if b is not None and b not in memo:
                memo.add(b)
                yield b
        for module in self.children():
            for b in module._all_buffers(memo):
                yield b
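
As a quick check, the buffers registered by a BatchNorm layer (see the BatchNorm walkthrough later in this article) show up in this traversal. A minimal sketch; note that `_all_buffers` is a private method of the older PyTorch version quoted here:

import torch.nn as nn

bn = nn.BatchNorm2d(16)  # registers running_mean and running_var as buffers
for buf in bn._all_buffers():
    print(buf.size())    # two tensors of size (16,)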

Accessing Modules

`children()` and `named_children()` return the immediate children of the current module. For AlexNet these are its two submodules:

>> for child in alexnet.children():
>>     print(child.__class__.__name__)
Sequential
Sequential
>> for name, module in alexnet.named_children():
>>     print(name,':', module.__class__.__name__)
features : Sequential
classifier : Sequential

The implementation:

    def children(self):
        for name, module in self.named_children():
            yield module

    def named_children(self):
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

`modules()` and `named_modules()` traverse, in order, every module in the tree (including the module itself). For AlexNet, for example:

>> for name, module in alexnet.named_modules(prefix='root'):
>>     print(name, ':', module.__class__.__name__)
root : AlexNet
root.features : Sequential
root.features.0 : Conv2d
root.features.1 : ReLU
root.features.2 : MaxPool2d
root.features.3 : Conv2d
root.features.4 : ReLU
root.features.5 : MaxPool2d
root.features.6 : Conv2d
root.features.7 : ReLU
root.features.8 : Conv2d
root.features.9 : ReLU
root.features.10 : Conv2d
root.features.11 : ReLU
root.features.12 : MaxPool2d
root.classifier : Sequential
root.classifier.0 : Dropout
root.classifier.1 : Linear
root.classifier.2 : ReLU
root.classifier.3 : Dropout
root.classifier.4 : Linear
root.classifier.5 : ReLU
root.classifier.6 : Linear

The implementation:


    def modules(self):
        for name, module in self.named_modules():
            yield module
            
    def named_modules(self, memo=None, prefix=''):
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix):
                    yield m

Accessing Parameters

`parameters()` and `named_parameters()` return an iterator over all parameters of the network. Each parameter's name expresses the hierarchy by joining the levels with ".". For AlexNet:

>>>  for name, param in alexnet.named_parameters():
>>>     print(name, ':', param.__class__.__name__, ':', param.size())
features.0.weight : Parameter : (64L, 3L, 11L, 11L)
features.0.bias : Parameter : (64L,)
features.3.weight : Parameter : (192L, 64L, 5L, 5L)
features.3.bias : Parameter : (192L,)
features.6.weight : Parameter : (384L, 192L, 3L, 3L)
features.6.bias : Parameter : (384L,)
features.8.weight : Parameter : (256L, 384L, 3L, 3L)
features.8.bias : Parameter : (256L,)
features.10.weight : Parameter : (256L, 256L, 3L, 3L)
features.10.bias : Parameter : (256L,)
classifier.1.weight : Parameter : (4096L, 9216L)
classifier.1.bias : Parameter : (4096L,)
classifier.4.weight : Parameter : (4096L, 4096L)
classifier.4.bias : Parameter : (4096L,)
classifier.6.weight : Parameter : (1000L, 4096L)
classifier.6.bias : Parameter : (1000L,)

Not every node in the tree has parameters; only some of them do, and these are usually the leaf nodes.

    def parameters(self):
        for name, param in self.named_parameters():
            yield param

    def named_parameters(self, memo=None, prefix=''):
        if memo is None:
            memo = set()
        for name, p in self._parameters.items():
            if p is not None and p not in memo:
                memo.add(p)
                yield prefix + ('.' if prefix else '') + name, p
        for mname, module in self.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in module.named_parameters(memo, submodule_prefix):
                yield name, p

Direct Access

`__getattr__` implements direct access to these data members by attribute name.

    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))
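
This is why a registered submodule or parameter can be read back as a plain attribute. A minimal sketch using the AlexNet class defined above:

alexnet = AlexNet()

# 'features' is not in alexnet.__dict__, so Python falls back to
# __getattr__, which finds it in _modules.
print('features' in alexnet.__dict__)                     # False
print(alexnet.features is alexnet._modules['features'])   # True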

Adding and Removing Data Members

Direct Addition and Removal

`__setattr__` and `__delattr__` route ordinary attribute assignment and deletion to the appropriate dictionary:


    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not torch.is_tensor(value):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
        elif name in self._modules:
            del self._modules[name]
        else:
            object.__delattr__(self, name)

Adding New Data Members

Some nodes in the tree have parameters (their `_parameters` is non-empty). We can add buffers and parameters with `register_buffer(name, tensor)` and `register_parameter(name, param)` respectively. A parameter is the familiar learnable data, for example a convolution layer's kernels. A buffer behaves differently: it is not learned as a network parameter but is merely stored as an attribute of the module, for example the running `mean` and `var` of BatchNorm.

Here we use the implementation of BatchNorm to briefly show how new data members are added. BatchNorm is defined as:

$$y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * gamma + beta$$

class _BatchNorm(Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def forward(self, input):
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' affine={affine})'
                .format(name=self.__class__.__name__, **self.__dict__))
    def register_buffer(self, name, tensor):
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))

        self._buffers[name] = tensor


    def register_parameter(self, name, param):
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")

        if hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Variable to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another variable, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param

`F.batch_norm` is defined in `nn.functional`. The difference between a `functional` and a module is that a `functional` only defines an abstract numerical operation, whereas a module not only has the corresponding numerical operation but also defines the additional state and methods it needs.
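
A minimal sketch of the contrast, using a stateless ReLU next to a stateful Linear (all names here are standard `torch.nn`; on the older, Variable-based PyTorch these tensors would be wrapped in `Variable` first):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 4)

# Functional form: a pure operation with no state of its own.
y1 = F.relu(x)

# Module form: wraps the same operation; for layers like Linear it also
# owns the parameters that the functional form expects to be passed in.
fc = nn.Linear(4, 3)                    # owns weight and bias Parameters
y2 = fc(x)                              # same as the functional call below
y3 = F.linear(x, fc.weight, fc.bias)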

Adding a Module

`add_module` adds a child module under the current module in the tree.

    def add_module(self, name, module):
        """Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (string): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module): child module to be added to the module.
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        if hasattr(self, name) and name not in self._modules:
            raise KeyError("attribute '{}' already exists".format(name))
        self._modules[name] = module
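
Typical use, as a minimal sketch (the class and names are hypothetical):

import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super(Tiny, self).__init__()
        # Equivalent to `self.fc = nn.Linear(4, 2)`, which reaches
        # _modules through __setattr__ instead.
        self.add_module('fc', nn.Linear(4, 2))

net = Tiny()
print(net.fc)  # readable as an attribute via __getattr__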

Data Conversion

Data conversion covers moving data between CPU and GPU, converting data types, and copying to shared memory. All of these are built on `_apply`, which applies a function to every parameter and buffer in the tree:

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param._grad is not None:
                    param._grad.data = fn(param._grad.data)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        return self
    
    def cuda(self, device=None):
        return self._apply(lambda t: t.cuda(device))

    def cpu(self):
        return self._apply(lambda t: t.cpu())

    def type(self, dst_type):
        return self._apply(lambda t: t.type(dst_type))

    def float(self):
        return self._apply(lambda t: t.float())

    def double(self):
        return self._apply(lambda t: t.double())

    def half(self):
        return self._apply(lambda t: t.half())

    def share_memory(self):
        return self._apply(lambda t: t.share_memory_())
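
Typical calls, as a sketch (whether `.cuda()` succeeds depends on a GPU being available):

import torch

model = AlexNet()

model.double()            # every parameter and buffer becomes float64
model.float()             # back to float32
if torch.cuda.is_available():
    model = model.cuda()  # moves all parameters and buffers to the GPU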

Switching Network State

Switching the network state only affects the behavior of layers such as Dropout and BatchNorm.

    def train(self, mode=True):
        """Sets the module in training mode.
    
        This has any effect only on modules such as Dropout or BatchNorm.
    
        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        """Sets the module in evaluation mode.
    
        This has any effect only on modules such as Dropout or BatchNorm.
        """
        return self.train(False)
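
The usual pattern, as a sketch:

model = AlexNet()

model.train()  # Dropout active, BatchNorm updates and uses batch statistics
# ... training loop ...

model.eval()   # Dropout disabled, BatchNorm uses its running statistics
# ... validation / inference ...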

Forward Pass

Every subclass must override `forward`. Calling a module invokes `__call__`, which runs the pre-forward hooks, then `forward` itself, then the forward hooks, and finally attaches any backward hooks to the result's `grad_fn`:

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}'"
                    "didn't return None".format(hook))
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, Variable):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, Variable)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result
    
    def forward(self, *input):
        raise NotImplementedError
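
This is why a module is invoked as `output = model(input)` rather than `model.forward(input)`: only the former goes through `__call__` and therefore fires the hooks. A minimal sketch (on this older PyTorch version the input is wrapped in a `Variable`):

import torch
from torch.autograd import Variable

model = AlexNet()
x = Variable(torch.randn(1, 3, 224, 224))
out = model(x)     # __call__ -> pre-hooks, forward, hooks
print(out.size())  # (1, 1000)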

Hooks

Each of the three `register_*_hook` methods stores the hook in the corresponding dictionary and returns a handle that can later remove it:

    def register_backward_hook(self, hook):
        """Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> Tensor or None

        The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
        module has multiple inputs or outputs. The hook should not modify its
        arguments, but it can optionally return a new gradient with respect to
        input that will be used in place of :attr:`grad_input` in subsequent
        computations.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle


    def register_forward_pre_hook(self, hook):
        """Registers a forward pre-hook on the module.

        The hook will be called every time before :func:`forward` is invoked.
        It should have the following signature::

            hook(module, input) -> None

        The hook should not modify the input.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle


    def register_forward_hook(self, hook):
        r"""Registers a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::

            hook(module, input, output) -> None

        The hook should not modify the input or output.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle
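
A minimal sketch of registering and later removing a forward hook; this one just logs output sizes:

import torch
from torch.autograd import Variable

def log_size(module, input, output):
    # forward hooks must not return a value
    print(module.__class__.__name__, '->', output.size())

model = AlexNet()
handle = model.features[0].register_forward_hook(log_size)
model(Variable(torch.randn(1, 3, 224, 224)))  # prints: Conv2d -> (1, 64, 55, 55)
handle.remove()                               # detach the hook again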

Pickle Support

`__setstate__` keeps older pickled modules loadable: if a saved module predates `_forward_pre_hooks`, the dictionary is recreated on load.

    def __setstate__(self, state):
        self.__dict__.update(state)
        if '_forward_pre_hooks' not in self.__dict__:
            self._forward_pre_hooks = OrderedDict()

Network Parameter Initialization

`apply(fn)` recursively calls itself on each child in `children()` and then applies `fn` to the module itself; this is the usual way to initialize the network's parameters:

def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
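
The canonical use, sketched as a custom weight initialization:

import torch.nn as nn

def init_weights(m):
    # fn sees every module in the tree, containers included,
    # so filter by type before touching weights.
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        m.bias.data.zero_()

model = AlexNet()
model.apply(init_weights)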

Serialization

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Returns a dictionary containing a whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are corresponding parameter and buffer names.

        When keep_vars is ``True``, it returns a Variable for each parameter
        (rather than a Tensor).

        Args:
            destination (dict, optional):
                if not None, the return dictionary is stored into destination.
                Default: None
            prefix (string, optional): Adds a prefix to the key (name) of every
                parameter and buffer in the result dictionary. Default: ''
            keep_vars (bool, optional): if ``True``, returns a Variable for each
                parameter. If ``False``, returns a Tensor for each parameter.
                Default: ``False``

        Returns:
            dict:
                a dictionary containing a whole state of the module

        Example:
            >>> module.state_dict().keys()
            ['bias', 'weight']
        """
        if destination is None:
            destination = OrderedDict()
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param if keep_vars else param.data
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
        return destination


    def load_state_dict(self, state_dict, strict=True):
        """Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True`` then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :func:`state_dict()` function.

        Arguments:
            state_dict (dict): A dict containing parameters and
                persistent buffers.
            strict (bool): Strictly enforce that the keys in :attr:`state_dict`
                match the keys returned by this module's `:func:`state_dict()`
                function.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    raise RuntimeError('While copying the parameter named {}, '
                                       'whose dimensions in the model are {} and '
                                       'whose dimensions in the checkpoint are {}.'
                                       .format(name, own_state[name].size(), param.size()))
            elif strict:
                raise KeyError('unexpected key "{}" in state_dict'
                               .format(name))
        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))

`__dir__` and `__repr__`

`__dir__` merges the class attributes with the keys of the three dictionaries, so introspection also sees parameters, buffers, and submodules; `__repr__` prints the module tree:

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        modules = list(self._modules.keys())
        buffers = list(self._buffers.keys())
        keys = module_attrs + attrs + parameters + modules + buffers
        return sorted(keys)

    def __repr__(self):
        tmpstr = self.__class__.__name__ + '(\n'
        for key, module in self._modules.items():
            modstr = module.__repr__()
            modstr = _addindent(modstr, 2)
            tmpstr = tmpstr + '  (' + key + '): ' + modstr + '\n'
        tmpstr = tmpstr + ')'
        return tmpstr

zero grad

    def zero_grad(self):
        """Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            if p.grad is not None:
                if p.grad.volatile:
                    p.grad.data.zero_()
                else:
                    data = p.grad.data
                    p.grad = Variable(data.new().resize_as_(data).zero_())
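
In a training loop this is called once per iteration, before backpropagation. A sketch; `criterion`, `optimizer`, `x`, and `target` are assumed to exist:

model.zero_grad()                    # clear gradients from the previous step
loss = criterion(model(x), target)
loss.backward()                      # gradients accumulate into p.grad
optimizer.step()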
