class MyModel1(nn.Module):
    """CNN whose linear layers are stored in a plain Python ``list``.

    Because the list is not an ``nn.ModuleList``, the ``Linear`` layers are
    NOT registered as submodules: they are invisible to ``parameters()``,
    ``state_dict()``, ``.to()`` etc.  This is intentional — the example
    demonstrates the registration pitfall.
    """

    def __init__(self):
        # Bug fix: the original called super(MyModel, self).__init__(),
        # a NameError — the class is named MyModel1.
        super(MyModel1, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        # Deliberately a plain list: these layers are NOT registered.
        self.linears = [nn.Linear(1024, 512), nn.Linear(512, 128), nn.Linear(128, 10)]

    def forward(self, x):
        """Forward pass.

        Expects x of shape (N, 1, 28, 28): conv1 -> 24x24, pool -> 12x12,
        conv2 -> 8x8, flatten -> 16*8*8 = 1024.  Returns (N, 10) logits.
        """
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        x = F.relu(self.linears[0](x))
        x = F.relu(self.linears[1](x))
        x = self.linears[2](x)
        return x
在继承 nn.Module
来构造自定义 module 时,父类会通过 __setattr__
来识别成员属性是否为 nn.Module
的实例,并将其注册为模块的子模块。但是,如果子模块没有直接赋值给对象属性,而是通过 Python 的 list 封装,那么子模块就不会被注册。
>>> net = MyModel1()
>>> net
MyModel1(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(flatten): Flatten(start_dim=1, end_dim=-1)
)
nn.ModuleList
为了解决上述的问题,PyTorch 中定义了 nn.ModuleList
对象。nn.ModuleList
也继承 nn.Module
类,并且实现了和 Python 内置list 相同的功能。
class ModuleList(nn.Module):
    """API sketch of ``torch.nn.ModuleList`` (method bodies omitted)."""

    # List-like methods of ModuleList.
    def __init__(self, modules):
        r"""
        Parameters:
            modules: iterable, optional, an iterable of modules to add.
        """
        pass

    def append(self, module):
        r"""
        Append a module to the end of the list.

        Parameters:
            module: nn.Module, the module to append.
        """
        pass

    def extend(self, modules):
        r"""
        Append modules from an iterable to the end of the list.

        Parameters:
            modules: iterable, an iterable of modules to append.
        """
        pass

    def insert(self, index, module):
        r"""
        Insert a module before the given index.

        Parameters:
            index: int, the position to insert before.
            module: nn.Module, the module to insert.
        """
        pass
nn.ModuleList
可以被识别添加为定义模块的子模块,extend、append 和 insert 方法和 list 中的用法类似。
class MyModel2(nn.Module):
    """Same CNN as MyModel1, but the linear layers live in an
    ``nn.ModuleList`` so they are properly registered as submodules."""

    def __init__(self):
        super(MyModel2, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        # Built incrementally to showcase the list-like API
        # (constructor, extend, append).
        self.linears = nn.ModuleList([nn.Linear(1024, 512)])
        self.linears.extend([nn.Linear(512, 128)])
        self.linears.append(nn.Linear(128, 10))

    def forward(self, x):
        out = self.pool(F.relu(self.conv1(x)))
        out = self.flatten(F.relu(self.conv2(out)))
        # ReLU after every linear layer except the last (raw logits).
        for layer in self.linears[:-1]:
            out = F.relu(layer(out))
        return self.linears[-1](out)
net = MyModel2()
net
输出:
MyModel2(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(flatten): Flatten(start_dim=1, end_dim=-1)
(linears): ModuleList(
(0): Linear(in_features=1024, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=128, bias=True)
(2): Linear(in_features=128, out_features=10, bias=True)
)
)
nn.ParameterList
使用和 nn.ModuleList
类似,可以将 nn.Parameter
添加为模块的参数。
nn.ModuleDict
class ModuleDict(nn.Module):
    """API sketch of ``torch.nn.ModuleDict`` (method bodies omitted)."""

    def __init__(self, modules):
        r"""
        Parameters:
            modules: iterable, optional, a mapping of string -> module,
                or an iterable of (string, module) pairs.
        """
        pass

    def clear(self):
        r"""Remove all items from the ModuleDict."""
        pass

    def items(self):
        r"""Return an iterable of (key, module) pairs."""
        pass

    def keys(self):
        r"""Return an iterable of the ModuleDict keys."""
        pass

    def values(self):
        r"""Return an iterable of the ModuleDict values (modules)."""
        pass

    def pop(self, key):
        r"""Remove the key from the ModuleDict and return its module."""
        pass

    def update(self, modules):
        r"""
        Update the dict with new key/module pairs, overwriting existing keys.

        Parameters:
            modules: iterable, a mapping from string to module,
                or an iterable of (string, module) pairs.
        """
        pass
定义自定义网络:
class MyModel3(nn.Module):
    """CNN whose linear layers are held in an ``nn.ModuleDict`` and
    addressed by name instead of position."""

    def __init__(self):
        super(MyModel3, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        self.linears = nn.ModuleDict({'linear1': nn.Linear(1024, 512),
                                      'linear2': nn.Linear(512, 128)})
        # update() mirrors dict.update: adds or overwrites entries.
        self.linears.update({'linear3': nn.Linear(128, 10)})

    def forward(self, x):
        h = self.pool(F.relu(self.conv1(x)))
        h = self.flatten(F.relu(self.conv2(h)))
        h = F.relu(self.linears['linear1'](h))
        h = F.relu(self.linears['linear2'](h))
        return self.linears['linear3'](h)
net = MyModel3()
print(net)
输出:
MyModel3(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(flatten): Flatten(start_dim=1, end_dim=-1)
(linears): ModuleDict(
(linear1): Linear(in_features=1024, out_features=512, bias=True)
(linear2): Linear(in_features=512, out_features=128, bias=True)
(linear3): Linear(in_features=128, out_features=10, bias=True)
)
)
nn.Sequential
nn.Sequential
同样继承了 nn.Module
类。与 nn.Module 不同的是,nn.Sequential 已经默认定义了 forward 函数,将输入按子模块添加的顺序依次前向传播。
例1:
class MyModel4(nn.Module):
    """CNN whose linear layers are packed into an ``nn.Sequential``.

    nn.Sequential defines forward() itself, chaining its children in the
    order they were given.
    """

    def __init__(self):
        super(MyModel4, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        self.linears = nn.Sequential(nn.Linear(1024, 512),
                                     nn.Linear(512, 128),
                                     nn.Linear(128, 10))

    def forward(self, x):
        """Forward pass; expects x of shape (N, 1, 28, 28) -> (N, 10)."""
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        # Bug fix: a Sequential built from positional args is indexed by
        # integer position — self.linears['linear1'] raised an exception.
        x = F.relu(self.linears[0](x))
        x = F.relu(self.linears[1](x))
        x = self.linears[2](x)
        return x
net = MyModel4()
print(net)
输出:
MyModel4(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(flatten): Flatten(start_dim=1, end_dim=-1)
(linears): Sequential(
(0): Linear(in_features=1024, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=128, bias=True)
(2): Linear(in_features=128, out_features=10, bias=True)
)
)
例2:
class MyModel5(nn.Module):
    """CNN whose ``nn.Sequential`` is built from an ``OrderedDict`` so the
    children get custom names ('1', '2', '3') in the printed repr."""

    def __init__(self):
        super(MyModel5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        self.linears = nn.Sequential(OrderedDict(
            [('1', nn.Linear(1024, 512)),
             ('2', nn.Linear(512, 128)),
             ('3', nn.Linear(128, 10))])
        )

    def forward(self, x):
        """Forward pass; expects x of shape (N, 1, 28, 28) -> (N, 10)."""
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        # Bug fix: nn.Sequential is indexed by integer position; the
        # OrderedDict keys ('1'/'2'/'3') only name the children for repr
        # and state_dict.  self.linears['linear1'] raised an exception.
        x = F.relu(self.linears[0](x))
        x = F.relu(self.linears[1](x))
        x = self.linears[2](x)
        return x
net = MyModel5()
print(net)
输出:
MyModel5(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(flatten): Flatten(start_dim=1, end_dim=-1)
(linears): Sequential(
(1): Linear(in_features=1024, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=128, bias=True)
(3): Linear(in_features=128, out_features=10, bias=True)
)
)