Module is the base class for all neural network modules; new modules should subclass it. A Module can also contain other Modules, allowing them to be nested in a tree structure, and a child Module can be assigned as a regular class attribute. We use AlexNet as a simple example alongside the definition of Module to illustrate this.
import torch.nn as nn
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
return x
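As a quick sanity check, here is a minimal usage sketch (assuming a PyTorch version where plain tensors can be fed to a module; on 0.3-era versions the input would be wrapped in Variable first). The 1x3x224x224 input shape is assumed so that the flatten matches 256 * 6 * 6 above.
import torch

alexnet = AlexNet(num_classes=1000)
x = torch.randn(1, 3, 224, 224)   # one dummy RGB image
out = alexnet(x)                   # Module.__call__ dispatches to forward()
print(out.size())                  # torch.Size([1, 1000])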
Every Module has its own _parameters, _buffers, _modules, and related data members.
class Module(object):
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
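To see how these dictionaries get filled, we can inspect the alexnet instance created above (a small sketch; the expected output is shown in comments and assumes the AlexNet definition from this section):
print(list(alexnet._modules.keys()))           # ['features', 'classifier']
print(list(alexnet._parameters.keys()))        # [] -- AlexNet itself owns no parameters directly
print(list(alexnet.features._modules.keys()))  # ['0', '1', ..., '12'] -- the 13 layers of features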
For a Module, ._modules is an ordered dictionary whose values are themselves Modules (each carrying its own _modules dict), so a Module defines the tree structure of the whole network by recursively nesting child modules, and every child is itself an instance of Module. For AlexNet, the resulting module tree is:
graph LR
AlexNet["AlexNet"]
features["'features':features"]
classifier["'classifier':classifier"]
AlexNet -- "_modules.values()[0]" --> features
AlexNet -- "_modules.values()[1]" --> classifier
subgraph AlexNet
AlexNet
subgraph Sequential
classifier --> classifier.0["'0':Dropout"]
classifier --> classifier.1["'1':Linear"]
classifier --> classifier.2["'2':ReLU"]
classifier --> classifier.3["'3':Dropout"]
classifier --> classifier.4["'4':Linear"]
classifier --> classifier.5["'5':ReLU"]
classifier --> classifier.6["'6':Linear"]
end
subgraph Sequential
features --> features.0["'0':Conv2d"]
features --> features.1["'1':ReLU"]
features --> features.2["'2':MaxPool2d"]
features --> features.3["'3':Conv2d"]
features --> features.4["'4':ReLU"]
features --> features.5["'5':MaxPool2d"]
features --> features.6["'6':Conv2d"]
features --> features.7["'7':ReLU"]
features --> features.8["'8':Conv2d"]
features --> features.9["'9':ReLU"]
features --> features.10["'10':Conv2d"]
features --> features.11["'11':ReLU"]
features --> features.12["'12':MaxPool2d"]
end
end
The main data members are _parameters, _buffers, and _modules; all three are OrderedDict instances. Each module only holds its own parameters, and the parameters of the full network are the union of those of all its submodules, so parameter access is implemented by recursively visiting every child module and yielding its entries. _all_buffers() below follows exactly this pattern for buffers:
def _all_buffers(self, memo=None):
if memo is None:
memo = set()
for name, b in self._buffers.items():
if b is not None and b not in memo:
memo.add(b)
yield b
for module in self.children():
for b in module._all_buffers(memo):
yield b
children() and named_children() return the immediate child modules of the current module. For AlexNet they yield two children:
>> for child in alexnet.children():
>> print(child.__class__.__name__)
Sequential
Sequential
>> for name, module in alexnet.named_children():
>> print(name,':', module.__class__.__name__)
features : Sequential
classifier : Sequential
def children(self):
for name, module in self.named_children():
yield module
def named_children(self):
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
modules() and named_modules() traverse every module in the tree, including the module itself. For example, for AlexNet:
>> for name, module in alexnet.named_modules(prefix='root'):
>> print(name, ':', module.__class__.__name__)
root : AlexNet
root.features : Sequential
root.features.0 : Conv2d
root.features.1 : ReLU
root.features.2 : MaxPool2d
root.features.3 : Conv2d
root.features.4 : ReLU
root.features.5 : MaxPool2d
root.features.6 : Conv2d
root.features.7 : ReLU
root.features.8 : Conv2d
root.features.9 : ReLU
root.features.10 : Conv2d
root.features.11 : ReLU
root.features.12 : MaxPool2d
root.classifier : Sequential
root.classifier.0 : Dropout
root.classifier.1 : Linear
root.classifier.2 : ReLU
root.classifier.3 : Dropout
root.classifier.4 : Linear
root.classifier.5 : ReLU
root.classifier.6 : Linear
The implementation is as follows:
def modules(self):
for name, module in self.named_modules():
yield module
def named_modules(self, memo=None, prefix=''):
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in module.named_modules(memo, submodule_prefix):
yield m
parameters() and named_parameters() return an iterator over all parameters of the network. Each parameter's name expresses the hierarchy by joining the levels with ".". For AlexNet:
>>> for name, param in alexnet.named_parameters():
>>> print(name, ':', param.__class__.__name__, ':', param.size())
features.0.weight : Parameter : (64L, 3L, 11L, 11L)
features.0.bias : Parameter : (64L,)
features.3.weight : Parameter : (192L, 64L, 5L, 5L)
features.3.bias : Parameter : (192L,)
features.6.weight : Parameter : (384L, 192L, 3L, 3L)
features.6.bias : Parameter : (384L,)
features.8.weight : Parameter : (256L, 384L, 3L, 3L)
features.8.bias : Parameter : (256L,)
features.10.weight : Parameter : (256L, 256L, 3L, 3L)
features.10.bias : Parameter : (256L,)
classifier.1.weight : Parameter : (4096L, 9216L)
classifier.1.bias : Parameter : (4096L,)
classifier.4.weight : Parameter : (4096L, 4096L)
classifier.4.bias : Parameter : (4096L,)
classifier.6.weight : Parameter : (1000L, 4096L)
classifier.6.bias : Parameter : (1000L,)
Not every node in the module tree owns parameters; only some nodes do, and they are usually the leaf nodes.
def parameters(self):
for name, param in self.named_parameters():
yield param
def named_parameters(self, memo=None, prefix=''):
if memo is None:
memo = set()
for name, p in self._parameters.items():
if p is not None and p not in memo:
memo.add(p)
yield prefix + ('.' if prefix else '') + name, p
for mname, module in self.named_children():
submodule_prefix = prefix + ('.' if prefix else '') + mname
for name, p in module.named_parameters(memo, submodule_prefix):
yield name, p
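As the output above suggests, only the leaf modules own parameters directly. A quick check of that claim, reusing the alexnet instance (a sketch, not part of the original source):
for name, module in alexnet.named_modules():
    if len(module._parameters) > 0:
        print(name, ':', list(module._parameters.keys()))
# features.0 : ['weight', 'bias']
# ...
# classifier.6 : ['weight', 'bias']
# neither the root AlexNet nor the two Sequential containers show up here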
__getattr__ lets parameters, buffers, and submodules be accessed directly as attributes by name.
def __getattr__(self, name):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
return _parameters[name]
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name]
if '_modules' in self.__dict__:
modules = self.__dict__['_modules']
if name in modules:
return modules[name]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
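In other words, attribute lookup falls through to these three dictionaries whenever the name is not found in the instance's __dict__. A small check, again reusing alexnet (a sketch):
print(alexnet.features is alexnet._modules['features'])   # True -- resolved via _modules
conv1 = alexnet.features._modules['0']
print(conv1.weight is conv1._parameters['weight'])         # True -- resolved via _parameters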
__setattr__ and __delattr__
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers)
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
modules[name] = value
else:
buffers = self.__dict__.get('_buffers')
if buffers is not None and name in buffers:
if value is not None and not torch.is_tensor(value):
raise TypeError("cannot assign '{}' as buffer '{}' "
"(torch.Tensor or None expected)"
.format(torch.typename(value), name))
buffers[name] = value
else:
object.__setattr__(self, name, value)
def __delattr__(self, name):
if name in self._parameters:
del self._parameters[name]
elif name in self._buffers:
del self._buffers[name]
elif name in self._modules:
del self._modules[name]
else:
object.__delattr__(self, name)
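__setattr__ is what routes ordinary attribute assignments into the right dictionary, and __delattr__ removes them again. A minimal sketch of this behaviour (the names sub and scale are made up for illustration):
m = nn.Module()
m.sub = nn.Linear(2, 3)                 # a Module value goes into _modules
m.scale = nn.Parameter(torch.ones(3))   # a Parameter value goes into _parameters
print(list(m._modules.keys()))          # ['sub']
print(list(m._parameters.keys()))       # ['scale']
del m.sub                               # __delattr__ drops it from _modules
print(list(m._modules.keys()))          # []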
Some nodes in the module tree own parameters (their _parameters dict is non-empty). A buffer or a parameter can be added with register_buffer(name, tensor) and register_parameter(name, param) respectively. A parameter is the usual learnable data, for example a convolution layer's kernels. A buffer behaves differently: it is not learned as a network parameter but is simply stored as module state, for example BatchNorm's running mean and variance.
The _BatchNorm implementation below briefly shows how such new data members are added. BatchNorm is defined as:
y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * gamma + beta
class _BatchNorm(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(_BatchNorm, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
self.momentum = momentum
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
if self.affine:
self.weight.data.uniform_()
self.bias.data.zero_()
def forward(self, input):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
def __repr__(self):
return ('{name}({num_features}, eps={eps}, momentum={momentum},'
' affine={affine})'
.format(name=self.__class__.__name__, **self.__dict__))
def register_buffer(self, name, tensor):
if hasattr(self, name) and name not in self._buffers:
raise KeyError("attribute '{}' already exists".format(name))
self._buffers[name] = tensor
def register_parameter(self, name, param):
if '_parameters' not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
if hasattr(self, name) and name not in self._parameters:
raise KeyError("attribute '{}' already exists".format(name))
if param is None:
self._parameters[name] = None
elif not isinstance(param, Parameter):
raise TypeError("cannot assign '{}' object to parameter '{}' "
"(torch.nn.Parameter or None required)"
.format(torch.typename(param), name))
elif param.grad_fn:
raise ValueError(
"Cannot assign non-leaf Variable to parameter '{0}'. Model "
"parameters must be created explicitly. To express '{0}' "
"as a function of another variable, compute the value in "
"the forward() method.".format(name))
else:
self._parameters[name] = param
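The practical difference shows up when inspecting a module that has both: parameters appear in parameters()/named_parameters(), while buffers only live in _buffers and in the state_dict. A hedged check with a built-in BatchNorm layer (buffer names follow the implementation above; newer PyTorch versions add an extra num_batches_tracked buffer):
bn = nn.BatchNorm2d(16)
print([name for name, _ in bn.named_parameters()])   # ['weight', 'bias'] -- learned
print(list(bn._buffers.keys()))                      # ['running_mean', 'running_var'] -- not learned
print(sorted(bn.state_dict().keys()))                # parameters and buffers are both saved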
F.batch_norm is defined in nn.functional. The difference between a functional and a module is that a functional only defines the abstract numerical rule of an operation, while a module carries not only that rule but also the parameters and methods the operation needs.
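The same contrast can be seen with a simpler operation: the module owns its state, while the functional expects the caller to pass it in. A small sketch (F.linear is used here purely as an illustration):
import torch.nn.functional as F

x = torch.randn(4, 16)
fc = nn.Linear(16, 8)                 # module: weight and bias live inside as Parameters
y1 = fc(x)
y2 = F.linear(x, fc.weight, fc.bias)  # functional: the state is supplied explicitly
print(torch.equal(y1, y2))            # True -- same numerical rule either way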
add_module adds a child node under the current module, the root of its own subtree:
def add_module(self, name, module):
"""Adds a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (string): name of the child module. The child module can be
accessed from this module using the given name
parameter (Module): child module to be added to the module.
"""
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(
torch.typename(module)))
if hasattr(self, name) and name not in self._modules:
raise KeyError("attribute '{}' already exists".format(name))
self._modules[name] = module
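A usage sketch (the name fc1 is made up for illustration):
net = nn.Module()
net.add_module('fc1', nn.Linear(10, 20))
print(list(net._modules.keys()))   # ['fc1']
print(net.fc1)                     # the child is reachable as an attribute via __getattr__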
Data conversion covers moving data between CPU and GPU, changing the data type, and copying it into shared memory; all of these are built on _apply:
def _apply(self, fn):
for module in self.children():
module._apply(fn)
for param in self._parameters.values():
if param is not None:
# Variables stored in modules are graph leaves, and we don't
# want to create copy nodes, so we have to unpack the data.
param.data = fn(param.data)
if param._grad is not None:
param._grad.data = fn(param._grad.data)
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
def cuda(self, device=None):
return self._apply(lambda t: t.cuda(device))
def cpu(self):
return self._apply(lambda t: t.cpu())
def type(self, dst_type):
return self._apply(lambda t: t.type(dst_type))
def float(self):
return self._apply(lambda t: t.float())
def double(self):
return self._apply(lambda t: t.double())
def half(self):
return self._apply(lambda t: t.half())
def share_memory(self):
return self._apply(lambda t: t.share_memory_())
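Because _apply rewrites param.data (plus gradients and buffers) in place, the same module object is reused after a conversion. A hedged usage sketch (cuda() is commented out since it needs a GPU):
alexnet.double()                               # all parameters and buffers become DoubleTensor
print(next(alexnet.parameters()).data.type())  # 'torch.DoubleTensor'
alexnet.float()                                # back to FloatTensor
# alexnet.cuda()                               # would move everything onto the GPU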
Switching the network's mode only affects mechanisms such as the Dropout and BatchNorm layers.
def train(self, mode=True):
"""Sets the module in training mode.
This has any effect only on modules such as Dropout or BatchNorm.
Returns:
Module: self
"""
self.training = mode
for module in self.children():
module.train(mode)
return self
def eval(self):
"""Sets the module in evaluation mode.
This has any effect only on modules such as Dropout or BatchNorm.
"""
return self.train(False)
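Since train()/eval() only flip the training flag recursively, the effect is easy to observe (a quick sketch reusing alexnet):
alexnet.eval()
dropout = alexnet.classifier._modules['0']
print(alexnet.training, dropout.training)   # False False -- propagated down to the Dropout layer
alexnet.train()
print(alexnet.training, dropout.training)   # True True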
The forward pass of the network has to be overridden by each subclass; __call__ wraps forward() with the hook machinery:
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
raise RuntimeError(
"forward hooks should never return any values, but '{}'"
"didn't return None".format(hook))
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, Variable):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, Variable)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
def forward(self, *input):
raise NotImplementedError
def register_backward_hook(self, hook):
"""Registers a backward hook on the module.
The hook will be called every time the gradients with respect to module
inputs are computed. The hook should have the following signature::
hook(module, grad_input, grad_output) -> Tensor or None
The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
module has multiple inputs or outputs. The hook should not modify its
arguments, but it can optionally return a new gradient with respect to
input that will be used in place of :attr:`grad_input` in subsequent
computations.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
def register_forward_pre_hook(self, hook):
"""Registers a forward pre-hook on the module.
The hook will be called every time before :func:`forward` is invoked.
It should have the following signature::
hook(module, input) -> None
The hook should not modify the input.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._forward_pre_hooks)
self._forward_pre_hooks[handle.id] = hook
return handle
def register_forward_hook(self, hook):
r"""Registers a forward hook on the module.
The hook will be called every time after :func:`forward` has computed an output.
It should have the following signature::
hook(module, input, output) -> None
The hook should not modify the input or output.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._forward_hooks)
self._forward_hooks[handle.id] = hook
return handle
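A forward-hook usage sketch (shape_logger is a made-up hook that only prints the output shape of the first convolution):
def shape_logger(module, input, output):
    # report the module's class and the shape it produced
    print(module.__class__.__name__, '->', tuple(output.size()))

handle = alexnet.features._modules['0'].register_forward_hook(shape_logger)
_ = alexnet(torch.randn(1, 3, 224, 224))   # prints: Conv2d -> (1, 64, 55, 55)
handle.remove()                            # detach the hook again via the returned handle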
def __setstate__(self, state):
self.__dict__.update(state)
if '_forward_pre_hooks' not in self.__dict__:
self._forward_pre_hooks = OrderedDict()
apply(fn) recursively calls apply(fn) on children() and can be used, for example, to initialize the network's parameters:
def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
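A typical use is weight initialization: fn receives every module in the tree, children first, and can decide per type what to do. A hedged sketch (init_weights is a made-up initializer, not part of the original code):
def init_weights(m):
    # only touch the layers that actually own weights
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            m.bias.data.zero_()

alexnet.apply(init_weights)   # visits every node of the module tree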
def state_dict(self, destination=None, prefix='', keep_vars=False):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
When keep_vars is ``True``, it returns a Variable for each parameter
(rather than a Tensor).
Args:
destination (dict, optional):
if not None, the return dictionary is stored into destination.
Default: None
prefix (string, optional): Adds a prefix to the key (name) of every
parameter and buffer in the result dictionary. Default: ''
keep_vars (bool, optional): if ``True``, returns a Variable for each
parameter. If ``False``, returns a Tensor for each parameter.
Default: ``False``
Returns:
dict:
a dictionary containing a whole state of the module
Example:
>>> module.state_dict().keys()
['bias', 'weight']
"""
if destination is None:
destination = OrderedDict()
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.data
for name, buf in self._buffers.items():
if buf is not None:
destination[prefix + name] = buf
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
return destination
def load_state_dict(self, state_dict, strict=True):
"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True`` then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :func:`state_dict()` function.
Arguments:
state_dict (dict): A dict containing parameters and
persistent buffers.
strict (bool): Strictly enforce that the keys in :attr:`state_dict`
match the keys returned by this module's `:func:`state_dict()`
function.
"""
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, Parameter):
# backwards compatibility for serialized parameters
param = param.data
try:
own_state[name].copy_(param)
except Exception:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if strict:
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
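Together, these two methods are the usual way to checkpoint and restore a model. A minimal sketch (the file name is hypothetical):
torch.save(alexnet.state_dict(), 'alexnet_params.pth')   # saves parameters and buffers, not the class
state = torch.load('alexnet_params.pth')
restored = AlexNet(num_classes=1000)
restored.load_state_dict(state)   # with strict=True the key sets must match exactly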
__dir__ and __repr__
def __dir__(self):
module_attrs = dir(self.__class__)
attrs = list(self.__dict__.keys())
parameters = list(self._parameters.keys())
modules = list(self._modules.keys())
buffers = list(self._buffers.keys())
keys = module_attrs + attrs + parameters + modules + buffers
return sorted(keys)
def __repr__(self):
tmpstr = self.__class__.__name__ + '(\n'
for key, module in self._modules.items():
modstr = module.__repr__()
modstr = _addindent(modstr, 2)
tmpstr = tmpstr + ' (' + key + '): ' + modstr + '\n'
tmpstr = tmpstr + ')'
return tmpstr
def zero_grad(self):
"""Sets gradients of all model parameters to zero."""
for p in self.parameters():
if p.grad is not None:
if p.grad.volatile:
p.grad.data.zero_()
else:
data = p.grad.data
p.grad = Variable(data.new().resize_as_(data).zero_())