Pytorch:torch.nn.ModuleList()、torch.nn.Sequential()

PyTorch 中有一些基础概念在构建网络的时候很重要,比如 nn.Module, nn.ModuleList, nn.Sequential,这些类我们称之为容器 (containers),因为我们可以添加模块 (module) 到它们之中。这些容器之间很容易混淆,本文中我们主要学习一下 nn.ModuleList 和 nn.Sequential,并判断在什么时候用哪一个比较合适。本文中的例子使用的是 PyTorch 1.0 版本。

1 torch.nn.ModuleList()

简单的说,就是把子模块存储在list中。它类似于list,既可以 append 操作,也可以做 insert 操作,也可以 extend 操作.。但是由于把layers存入Modulelist中后只是完成了存储作用,所以不能直接在forward中直接运行,需要通过索引调出相应的submodule。

torch.nn.ModuleList(modules=None)

不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。描述看起来很枯燥,我们来看几个例子。

  • 栗子1:第一个网络,使用 nn.ModuleList 来构建一个小型网络,包括3个全连接层:

class net1(nn.Module):
    def __init__(self):
        super(net1, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(2)])
    def forward(self, x):
        for m in self.linears:
            x = m(x)
        return x

net = net1()
print(net)

# net1(
#   (modules): ModuleList(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

for param in net.parameters():
    print(type(param.data), param.size())

#  torch.Size([10, 10])
#  torch.Size([10])
#  torch.Size([10, 10])
#  torch.Size([10])

我们可以看到,这个网络包含两个全连接层,他们的权重 (weithgs) 和偏置 (bias) 都在这个网络之内。

  • 栗子2:第二个网络,使用 Python 自带的 list:
class net2(nn.Module):
    def __init__(self):
        super(net2, self).__init__()
        self.linears = [nn.Linear(10,10) for i in range(2)]
    def forward(self, x):
        for m in self.linears:
            x = m(x)
        return x

net = net2()
print(net)

# net2()
print(list(net.parameters()))
# []

显然,使用 Python 的 list 添加的全连接层和它们的 parameters 并没有自动注册到我们的网络中。当然,我们还是可以使用 forward 来计算输出结果。但是如果用 net2 实例化的网络进行训练的时候,因为这些层的 parameters 不在整个网络之中,所以其网络参数也不会被更新。

到这里,我们大致明白了 nn.ModuleList 是干什么的了:它是一个储存不同 module,并自动将每个 module 的 parameters 添加到网络之中的容器。但是,我们需要注意到,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言,比如:

class net3(nn.Module):
    def __init__(self):
        super(net3, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
    def forward(self, x):
        x = self.linears[2](x)
        x = self.linears[0](x)
        x = self.linears[1](x) 
        return x

net = net3()
print(net)

# net3(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=20, bias=True)
#     (1): Linear(in_features=20, out_features=30, bias=True)
#     (2): Linear(in_features=5, out_features=10, bias=True)
#   )
# )

input = torch.randn(32, 5)
print(net(input).shape)
# torch.Size([32, 30])

根据 net3 的结果,我们可以看出来这个 ModuleList 里面的顺序并不能决定什么,网络的执行顺序是根据 forward 函数来决定的。如果你非要 ModuleList 和 forward 中的顺序不一样, PyTorch 表示它无所谓,但以后 review 你代码的人可能会意见比较大。

我们再考虑另外一种情况,既然这个 ModuleList 可以根据序号来调用,那么一个模块是否可以在 forward 函数中被调用多次呢?答案当然是可以的,但是,被调用多次的模块,是使用同一组 parameters 的,也就是它们的参数是完全一样的,无论你之后怎么更新。例子如下,虽然在 forward 中我们用了 nn.Linear(10,10) 两次,但是它们只有一组参数。这么做有什么用处呢,我目前没有想到…

class net4(nn.Module):
    def __init__(self):
        super(net4, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(5, 10), nn.Linear(10, 10)])
    def forward(self, x):
        x = self.linears[0](x)
        x = self.linears[1](x)
        x = self.linears[1](x)
        return x

net = net4()
print(net)

# net4(
#   (linears): ModuleList(
#     (0): Linear(in_features=5, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

for name, param in net.named_parameters():
    print(name, param.size())

# linears.0.weight torch.Size([10, 5])
# linears.0.bias torch.Size([10])
# linears.1.weight torch.Size([10, 10])
# linears.1.bias torch.Size([10])

2 torch.nn.Sequential()

顺序容器.模块将按照顺序存进sequential中,相当于一个包装起来的子模块集,可以在forward中直接运行。不同于 nn.ModuleList,它已经实现的 forward 函数,而且里面的模块是按照顺序进行排列的,所以我们必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。

 torch.nn.Sequential(*args)
  • 栗子:
class net5(nn.Module):
    def __init__(self):
        super(net5, self).__init__()
        self.block = nn.Sequential(nn.Conv2d(1,20,5),
                                    nn.ReLU(),
                                    nn.Conv2d(20,64,5),
                                    nn.ReLU())
    def forward(self, x):
        x = self.block(x)
        return x

net = net5()
print(net)

# net5(
#   (block): Sequential(
#     (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#     (1): ReLU()
#     (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#     (3): ReLU()
#   )
# )

model1 和 从类 net5 实例化来的 net 这两个网络是相同的,因为 nn.Sequential 就是一个 nn.Module 的子类,也就是 nn.Module 所有的方法 (method) 它都有。并且直接使用 nn.Sequential 不用写 forward 函数,因为它内部已经帮你写好了。

这时候有同学该说了,既然 nn.Sequential 这么好,我以后都直接用它了。如果你确定 nn.Sequential 里面的顺序是你想要的,而且不需要再添加一些其他处理的函数 (比如 nn.functional 里面的函数,nn 与 nn.functional 有什么区别? ),那么完全可以直接用 nn.Sequential。这么做的代价就是失去了部分灵活性,毕竟不能自己去定制 forward 函数里面的内容了。

一般情况下 nn.Sequential 的用法是来组成卷积块 (block),然后像拼积木一样把不同的 block 拼成整个网络,让代码更简洁,更加结构化。

3 nn.ModuleList 和 nn.Sequential: 到底该用哪个

  • 场景一:有的时候网络中有很多相似或者重复的层,我们一般会考虑用 for 循环来创建它们,比如:
layers = [nn.Linear(10, 10) for i in range(5)]

这个时候,很自然而然地,我们会想到使用 ModuleList,像这样:

class net6(nn.Module):
    def __init__(self):
        super(net6, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])

    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x

net = net6()
print(net)

# net6(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#     (2): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

这个是比较一般的方法,但如果不想这么麻烦,我们也可以用 Sequential 来实现,如 net7 所示!注意 * 这个操作符,它可以把一个 list 拆开成一个个独立的元素。所以在 场景一 中,使用 net7 这种方法比较方便和整洁:

class net7(nn.Module):
    def __init__(self):
        super(net7, self).__init__()
        self.linear_list = [nn.Linear(10, 10) for i in range(3)]
        self.linears = nn.Sequential(*self.linears_list)

    def forward(self, x):
        self.x = self.linears(x)
        return x

net = net7()
print(net)

# net7(
#   (linears): Sequential(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#     (2): Linear(in_features=10, out_features=10, bias=True)
#   )
# )
  • 场景二:当我们需要之前层的信息的时候,比如 ResNets 中的 shortcut 结构,或者是像 FCN 中用到的 skip architecture 之类的,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList 比较方便,一个非常简单的例子如下:
class net8(nn.Module):
    def __init__(self):
        super(net8, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 20), nn.Linear(20, 30), nn.Linear(30, 50)])
        self.trace = []

    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
            self.trace.append(x)
        return x

net = net8()
input  = torch.randn(32, 10)
output = net(input)
for each in net.trace:
    print(each.shape)

# torch.Size([32, 20])
# torch.Size([32, 30])
# torch.Size([32, 50])

我们使用了一个 trace 的列表来储存网络每层的输出结果,这样如果以后的层要用的话,就可以很方便的调用了。

参考:

  • https://blog.csdn.net/byron123456sfsfsfa/article/details/89930990​​​​​​​

你可能感兴趣的:(修仙之路:pytorch篇)