nn.Module是使用pytorch进行神经网络训练的主要载体,是所有网络的基类。首先看一下它的构造函数:
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._modules = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
下面分类来看nn.Module中的其他非辅助方法:
cuda:将所有的parameters和buffers移动到gpu
cpu:将所有的parameters和buffers移动到cpu
type:将所有的parameters和buffers都转换为指定的目标类型
float:将所有的parameters和buffers都转换为float类型
double:将所有的parameters和buffers都转换为double类型
half:将所有的parameters和buffers都转换为float16类型
to:该函数有四种用法:
上述所有函数的功能均借助_apply完成:
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def compute_should_use_set_data(tensor, tensor_applied):
# 是否进行in-place操作
for key, param in self._parameters.items():
if param is not None:
# Tensors stored in modules are graph leaves, and we don't want to
# track autograd history of `param_applied`, so we have to use
# `with torch.no_grad():`
with torch.no_grad():
param_applied = fn(param)
should_use_set_data = compute_should_use_set_data(param, param_applied)
if should_use_set_data:
# 直接替换旧数据
param.data = param_applied
else:
assert isinstance(param, Parameter)
assert param.is_leaf
# 注册新的Parameter
self._parameters[key] = Parameter(param_applied, param.requires_grad)
# 对param.grad进行相同的操作
if param.grad is not None:
with torch.no_grad():
grad_applied = fn(param.grad)
should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
if should_use_set_data:
param.grad.data = grad_applied
else:
assert param.grad.is_leaf
self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)
# 更新_buffers
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
_apply函数遍历了所有的parameters(如果有grad还要遍历grad)和buffers,对它们应用fn,并将fn作用的结果注册到相应的容器中。
register_backward_hook:向self._backward_hooks注册新元素
register_forward_pre_hook:向self._forward_pre_hooks注册新元素
register_forward_hook:向self._forward_hooks注册新元素
关于hook技术的详细介绍可参考我的里一篇文章pytorch的自动求导和hook技术简介
state_dict:返回一个包含模型所有parameters和buffers的字典。需要注意的是,所有的参数在被添加到字典中之前,都要经过self._state_dict_hooks中的hook函数处理。
该函数是个递归函数,每次递归调用了辅助函数_save_to_state_dict来将当前module的子module参数添加到字典中。
load_state_dict:将指定的state_dict中的参数加载到当前的module中。需要注意的是,所有的参数在被添加到当前模型的_parameters和 _ buffers之前,都要经过self. _load_state_dict_pre_hooks中hook函数处理。
该函数中定义了一个递归函数load。load每次递归时调用辅助函数_load_from_state_dict来添加当前module的子module参数。
该函数有一个输入参数strict,默认为True。strict为True时,会返回一个namedtuple。该tuple有两个属性:missing_keys和unexpected_keys。missing_keys是存储当前module有而待加载的state_dict中没有的参数的列表;unexpected_keys是存储当前module没有而待加载的state_dict中有的参数的列表。
named_modules:返回模型中的所有module(包括模型本身)及其名称。返回值中有两个值,第一个为module的名(string格式),第二个为相应的module。返回顺序是自顶向下。
modules:返回模型中的所有module,调用named_modules完成。
named_children:返回当前模型_modules中的所有元素及其名称。注意该函数与named_modules的区别:named_modules是递归函数,每次递归均查询当前module的 _modules,进而能够遍历模型中的所有module;而named_chilldren只查询整个模型的 _modules
children:返回当前模型_modules中的所有元素,调用named_modules来完成。
named_parameters:返回模型中所有的可训练参数及其名称。
parameters:返回模型中所有的可训练参数。
named_buffers:返回模型中所有的非训练参数及其名称。
buffers:返回模型中所有的非训练参数。
上述所有的方法返回的都是由yield定义的生成器。
前两个方法借助named_modules来完成。
def named_modules(self, memo=None, prefix=''):
r"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
Yields:
(string, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
"""
if memo is None:
memo = set()
if self not in memo:
memo.add(self) # 包括当前模型自身
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
# 递归返回每一级所有module
for m in module.named_modules(memo, submodule_prefix):
yield m
named_children和children直接返回_modules中的所有元素。
最后四个方法调用了辅助函数_named_members来实现:
def _named_members(self, get_members_fn, prefix='', recurse=True):
r"""Helper method for yielding various names + members of modules."""
memo = set()
# 调用named_modules获得模型中的所有module
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = get_members_fn(module) # 获得当前module的所有参数
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
yield name, v
用一个简单的示例来学习他们的区别:
import torch
import torch.nn as nn
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
)
self.fc = nn.Linear(32, 3)
def forward(self, x):
x = self.conv(x)
x = x.view(-1, 32)
x = self.fc(x)
return x
model = net()
for k, m in model._modules.items():
print(k, m)
# conv Sequential(
# (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
# (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace=True)
# )
# fc Linear(in_features=32, out_features=3, bias=True)
for k, v in model.named_modules():
print(k, v)
# net(
# (conv): Sequential(
# (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
# (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace=True)
# )
# (fc): Linear(in_features=32, out_features=3, bias=True)
# )
# conv Sequential(
# (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
# (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace=True)
# )
# conv.0 Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
# conv.1 BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# conv.2 ReLU(inplace=True)
# fc Linear(in_features=32, out_features=3, bias=True)
for k, v in model.named_children():
print(k, v)
# conv Sequential(
# (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
# (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace=True)
# )
# fc Linear(in_features=32, out_features=3, bias=True)
# 注意children()和modules()的区别:children只返回一级子module
for k, _ in model.named_parameters():
print(k)
# conv.0.weight
# conv.1.weight
# conv.1.bias
# fc.weight
# fc.bias
for k, _ in model.named_buffers():
print(k)
# conv.1.running_mean
# conv.1.running_var
# conv.1.num_batches_tracked
apply:将自定义函数作用于模型的所有子module。
def apply(self, fn):
for module in self.children():
module.apply(fn)
fn(self)
return self