为方便介绍,这里先通过继承 nn.Module 来定义一个Net网络。
import torch
from torch import nn
from torch.nn import functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.flatten = nn.Flatten()
self.linear1 = nn.Linear(1024, 512)
self.linear2 = nn.Linear(512, 128)
self.linear3 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.flatten(x)
x = F.relu(self.linear1(x))
x = F.relu(self.linear2(x))
x = self.linear3(x)
return x
net = Model()
在实例化自定义的网络时,会执行构造函数 __init__()
,为了是后续的操作正确执行,需要在定义类时在构造函数 __init__()
中首先调用执行父类的 __init__()
函数:super(Model, self).__init__()
,这样才会创造出Net必须有的字典属性:
_parameters
_modules
_buffers
_backward_hooks
_forward_hooks
_forward_pre_hooks
_state_dict_hooks
_load_state_dict_pre_hooks
_parameters
, _modules
, _buffers
nn.Module
使用了 Python 的 __setattr__
机制,当在类中定义成员时,__setattr__
会检测成员的 type 派生于哪些类型。如果派生于 Parameter
类,则被归于 _parameters
;如果派生于 Module
,则划归于 _modules
。因此,如果类中定义的成员被封装到Python的普通数据类型中,则不会自动归类,比如:self.layers = [nn.Linear(1024, 80), nn.Linear(80, 10]
,检测到是list类型,则会视为普通属性。
_parameters
当直接调用 net._parameters
时,会发现,字典为空。因为在定义的网络的成员没有直接派生于 Parameter
类的,所以该方法返回空字典。这时可以使用 net.parameters()
方法,该方法返回一个迭代器,递归获取每层的参数。
>>> net._parameters
OrderedDict()
>>> for i in net.parameters():
print(i.__class__)
break
<class 'torch.nn.parameter.Parameter'>
_modules
_modules
包含了类所有的派生于 Module
的成员,前面说了,如果成员被封装到列表中,并不会被添加到 _modules
中,这时,如有必要,可以使用 ModuleList
替代列表来使用,ModuleList
继承了 类,实现了list的功能。
>>> net._modules
OrderedDict([('conv1', Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))),
('pool',
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)),
('conv2', Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))),
('flatten', Flatten(start_dim=1, end_dim=-1)),
('linear1', Linear(in_features=1024, out_features=80, bias=True)),
('linear2', Linear(in_features=80, out_features=10, bias=True))])
_buffers
该成员值的填充是通过register_buffer API来完成的,通常用来将一些需要持久化的状态(但又不是网络的参数)放到_buffer里;一些极其个别的操作,比如BN,会将running_mean的值放入进来。
_backward_hooks
, _forward_hooks
, _forward_pre_hooks
hook 函数可以在不改变网络的主体的情况下,实现一些额外的功能。因为 Pytorch 中动态运算图的机制,网络计算的中间变量会在计算结束后释放以节省性能。此时,可以通过在某一层网络上挂上一些 hook 函数来获取该层的中间变量。
hook 函数只需要是一个可调用对象就可以(实现了 __call__
)。PyTorch 为网络提供了三种 hook:
forward_pre_hooks
:nn.Module
提供了 net.register_forward_pre_hook(hook)
方法来注册该 hook:hook(module, input)
。该 hook 函数用于获取网络层的 input。
forward_hooks
:nn.Module
提供了 net.register_forward_hook(hook)
方法来注册该 hook:hook(module, input, output)
。该 hook 函数用于获取 module 的 input 和 output。
backward_hooks
:nn.Module
提供了 net.register_backward_hook(hook)
方法来注册该hook:hook(module, grad_input, grad_output)
。该 hook 函数用于获取反向传播中 module 的grad_in,grad_out。
net = Model()
x = torch.rand(10, 1, 28, 28)
y = net(x)
x = x + 0.1*torch.rand(x.shape)
loss = torch.nn.L1Loss()
def forward_pre_hooks(module, input):
r"""前向传播前hook函数"""
print("forward_pre_hooks 的输出:")
print("module: ", module, "\ninput shape: ", input[0].shape)
def forward_hooks(module, input, output):
r"""前向传播hook函数"""
print("\n\nforward_hooks 的输出:")
print("module: ", module, "\ninput shape: ", input[0].shape, "\noutput shape: ", output[0].shape)
def backward_hooks(module, grad_input, grad_output):
r"""反向传播hook函数"""
print("\n\nbackward_hooks 的输出:")
print("module: ", module, "\ngrad_input shape: ", grad_input[0].shape, "\ngrad_output shape: ", grad_output[0].shape)
net.conv2.register_forward_pre_hook(forward_pre_hooks)
net.conv2.register_forward_hook(forward_hooks)
net.conv2.register_backward_hook(backward_hooks)
y_hat = net(x)
l = loss(y_hat, y)
l.sum().backward()
输出:
forward_pre_hooks 的输出:
module: Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
input shape: torch.Size([10, 6, 12, 12])
forward_hooks 的输出:
module: Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
input shape: torch.Size([10, 6, 12, 12])
output shape: torch.Size([16, 8, 8])
backward_hooks 的输出:
module: Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
grad_input shape: torch.Size([10, 6, 12, 12])
grad_output shape: torch.Size([10, 16, 8, 8])
_backward_hooks
, _forward_hooks
, _forward_pre_hooks
中存放了对应的hook。在网络进行前向传播时:首先会执行 _forward_pre_hooks
中的 hooks,然后执行网络的 forward 函数;再然后执行 _forward_hooks
中的 hooks 函数。当发生反向传播时,会依次执行 _backward_hooks
中的 hooks。
>>> net.conv2._forward_pre_hooks
OrderedDict([(43, <function __main__.forward_pre_hooks(module, input)>)])
>>> net.conv2._forward_hooks
OrderedDict([(44, <function __main__.forward_hooks(module, input, output)>)])
>>> net.conv2._backward_hooks
OrderedDict([(45,
<function __main__.backward_hooks(module, grad_input, grad_output)>)])
_state_dict_hooks
, _load_state_dict_pre_hooks