与Module, ModuleDict, Sequential同属于容器
和通常的python list一样进行append, extend, insert操作,但是参数会自动注册
'''from PYTORCH DOCUMENTATION'''
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
list需要手动注册参数, 如果不注册不会打印出任何结果
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super().__init__()
self.linears = [nn.Linear(4,4), nn.Linear(4,4), nn.Linear(4, 4)]
#self.linears = nn.ModuleList[nn.Linear(4,4), nn.Linear(4,4), nn.Linear(4, 4)]
#self.linears = nn.ModuleList([nn.Linear(4,4) for i in range(3)])
'''手动注册参数'''
for i, layer in enumerate(self.linears):
layer.weight = nn.Parameter(torch.rand(4, 4))
self.register_parameter('weight'+str(i), layer.weight)
def forward(self, x):
for linear in self.linears:
x = linear(x)
x = F.relu(x)
return x
net = Net()
for parameter in net.parameters():
print(parameter)
ModuleList中的模块顺序与实际网络中数据流动的顺序无关,取决于forward中的定义
参考PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景 - 知乎 (zhihu.com)
nn.Sequential与nn.ModuleList-CSDN博客
torch.nn — PyTorch master documentation
PyTorch之Container深度理解 - 知乎 (zhihu.com)
Pytorch参数注册和nn.ModuleList nn.ModuleDict的问题_python_脚本之家 (jb51.net)
register_parameter和register_buffer 详解-CSDN博客
PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景 - 知乎 (zhihu.com)