torch.nn.ModuleList()

torch.nn.ModuleList()

与Module, ModuleDict, Sequential同属于容器

和通常的python list一样进行append, extend, insert操作,但是参数会自动注册

示例

'''from PYTORCH DOCUMENTATION'''
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

和List的区别

list需要手动注册参数, 如果不注册不会打印出任何结果

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = [nn.Linear(4,4), nn.Linear(4,4), nn.Linear(4, 4)]
        #self.linears = nn.ModuleList[nn.Linear(4,4), nn.Linear(4,4), nn.Linear(4, 4)]
        #self.linears = nn.ModuleList([nn.Linear(4,4) for i in range(3)]) 
        '''手动注册参数'''
        for i, layer in enumerate(self.linears):
            layer.weight = nn.Parameter(torch.rand(4, 4))
            self.register_parameter('weight'+str(i), layer.weight)
    def forward(self, x):
        for linear in self.linears:
            x = linear(x)
            x = F.relu(x)
        return x

net = Net()
for parameter in net.parameters():
                print(parameter)

和Sequential的区别

ModuleList中的模块顺序与实际网络中数据流动的顺序无关,取决于forward中的定义
参考PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景 - 知乎 (zhihu.com)

Related Links

nn.Sequential与nn.ModuleList-CSDN博客

torch.nn — PyTorch master documentation

PyTorch之Container深度理解 - 知乎 (zhihu.com)

Pytorch参数注册和nn.ModuleList nn.ModuleDict的问题_python_脚本之家 (jb51.net)
register_parameter和register_buffer 详解-CSDN博客

PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景 - 知乎 (zhihu.com)

你可能感兴趣的:(pytorch)