nn.Sequential、nn.ModuleList、nn.ModuleDict区别及使用技巧

目录

一、区别及联系

二、使用技巧

2.1、nn.Sequential()

2.2、nn.ModuleList()

2.3、nn.ModuleDict() 


一、区别及联系

nn.Sequential、nn.ModuleList、nn.ModuleDict区别及使用技巧_第1张图片

 先通过图片总结了解三个容器方法的主要区别:

  1. nn.Sequential容器自带forward()方法,无需显示调用。nn.ModuList和nn.ModuleDict自身不具有forward()方法。
  2. nn.Sequential内的网络层必须顺序执行,上一层的输出必须与下一层的输入大小一致。
  3. nn.ModuleDict和nn.ModuleList容器内的网络层无需按顺序执行。

二、使用技巧

2.1、nn.Sequential()

可以直接添加网络层、也可以先声明后利用add_module(name:str,module)方法添加网络层,还可以使用OrderDict([*(name:str,module)])函数添加。

net1 = nn.Sequential(
    nn.Conv2d(3,6,kernel_size=5),
    nn.Conv2d(6,10,kernel_size=3),
    nn.BatchNorm2d(10),
    nn.ReLU(),
)

net2 = nn.Sequential()
net2.add_module('conv1',nn.Conv2d(3,6,kernel_size=5))
net2.add_module('conv2',nn.Conv2d(6,10,kernel_size=3))
net2.add_module('bn',nn.BatchNorm2d(10))
net2.add_module('relu',nn.ReLU())

net3 = nn.Sequential(OrderedDict([
    ['conv1',nn.Conv2d(3,6,kernel_size=5)],
    ('conv2',nn.Conv2d(6,10,kernel_size=3))
]))

print('#####################')
print(net1)
print('#####################')
print(net2)
print('#####################')
print(net3)

输出结果为

#####################
Sequential(
  (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
  (2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU()
)
#####################
Sequential(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
  (bn): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
)
#####################
Sequential(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
)

2.2、nn.ModuleList()

nn.ModuleList里面储存了不同 module,并自动将每个 module 的 parameters 添加到网络容器内容(注册),里面的module是按照List的形式顺序存储的,但是在forward中调用的时候可以随意组合。可以任意将 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,也就是说它可以使用 extend,append 等操作。
 

model = nn.ModuleList([
    nn.Conv2d(3, 6, kernel_size=5),
    nn.Conv2d(6, 10, kernel_size=3),
    nn.BatchNorm2d(10),
    nn.ReLU(),
])
model.extend([nn.Linear(10,10) for i in range(5)])
print(model)

 输出结果为:

ModuleList(
  (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
  (2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU()
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): Linear(in_features=10, out_features=10, bias=True)
  (6): Linear(in_features=10, out_features=10, bias=True)
  (7): Linear(in_features=10, out_features=10, bias=True)
  (8): Linear(in_features=10, out_features=10, bias=True)
)

运行模块可以直接使用列表索引方式或者利用for循环调用,但是顺序不固定

input = torch.randn(1,6,3,3)
out = model[1](input)
print(out.shape)

#view():[1,10,1,1]->[1,10]
out = out.view(out.shape[0],out.shape[1])

out = [model[i](out) for i in range(4,7)]
for o in out:
    print(o.shape)

######################

torch.Size([1, 10, 1, 1])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])

2.3、nn.ModuleDict() 

nn.ModuleDict书写格式也分为两种:一种是nn.ModuleDict( {name:module , name:module ,...} ),另一种是nn.ModuleDict([ [name,module] , [name,module], ... ])

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        # x = self.choices[choice](x)
        # x = self.activations[act](x)
        return x

net = MyNet()
print(net)

输出结果为

MyNet(
  (choices): ModuleDict(
    (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
    (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  )
  (activations): ModuleDict(
    (lrelu): LeakyReLU(negative_slope=0.01)
    (prelu): PReLU(num_parameters=1)
  )
)

三、参考文献

PyTorch中的Sequential、ModuleList和ModuleDict用法总结_非晚非晚的博客-CSDN博客

nn.Sequential与nn.ModuleList_HySmiley的博客-CSDN博客

pytorch模型容器Containers nn.ModuleDict、nn.moduleList、nn.Sequential_nn.moduledict()_发呆的比目鱼的博客-CSDN博客

你可能感兴趣的:(pytorch基础变成,笔记)