pytorch nn.ModuleList() 和nn.Sequential()

在遇到一些稍微复杂的搭建模型的需求的时候,使用pytorch中的 nn.ModuleList() 和nn.Sequential()可以方便很多。

使用 ModuleList 可以简化写法。

这里需要讲的是,ModuleList 可以存储多个 model,传统的方法,一个model 就要写一个 forward ,但是如果将它们存到一个 ModuleList 的话,就可以使用一个 forward。
ModuleList是Module的子类,当在Module中使用它的时候,就能自动识别为子module。
当添加 nn.ModuleList 作为 nn.Module 对象的一个成员时(即当我们添加模块到我们的网络时),所有 nn.ModuleList 内部的 nn.Module 的 parameter 也被添加作为 我们的网络的 parameter。

class model2(nn.Module):
    def __init__(self):
        super(model2, self).__init__()
        self.layers=nn.ModuleList([
            nn.Linear(1,10), nn.ReLU(),
            nn.Linear(10,100),nn.ReLU(),
            nn.Linear(100,10),nn.ReLU(),
            nn.Linear(10,1)])
    def forward(self,x):
        out=x
        for i,layer in enumerate(self.layers):
            out=layer(out)
        return out

其它用法
ModuleList 具有和List 相似的用法,实际上可以把它视作是 Module 和 list 的结合。
除了在创建 ModuleList 的时候传入一个 module 的 列表,还可以使用extend 函数和 append 函数来添加模型。
1.extend 方法
和 list 相似,参数为一个元素为 Module的列表,该方法的效果是将列表中的所有 Module 添加到 ModuleList中:

self.linears.extend([nn.Linear(size1, size2) for i in range(1, num_layers)])

2.append 方法
和list 的append 方法一样,将 一个 Module 添加到ModuleList。

self.linears.append(nn.Linear(size1, size2)

使用 nn.Sequential()

class model3(nn.Module):
    def __init__(self):
        super(model3, self).__init__()
        self.network=nn.Sequential(
            nn.Linear(1,10),nn.ReLU(),
            nn.Linear(10,100),nn.ReLU(),
            nn.Linear(100,10),nn.ReLU(),
            nn.Linear(10,1)
        )
    def forward(self, x):
        return self.network(x)

你可能感兴趣的:(pytorch nn.ModuleList() 和nn.Sequential())