pytorch中的 nn.ModuleList 和 nn.Sequential

nn.ModuleList() 和 nn.Sequential() 都可以用来搭建神经网络。nn.ModuleList()函数是用来存储各个模块,前后模块是没有关联的。

class net(nn.Module):
    def __init__(self):
        super(net6, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])
 
    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x
 
net = net()
print(net)
# net(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#     (2): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

*可以将列表解包

class net(nn.Module):
    def __init__(self):
        super(net7, self).__init__()
        self.linear_list = [nn.Linear(10, 10) for i in range(3)]
        self.linears = nn.Sequential(*self.linear_list)  ###  *可以将列表解包
 
    def forward(se

你可能感兴趣的:(深度学习,python,pytorch,深度学习)