PyTorch 的容器(Container)包括 nn.ModuleList、nn.ModuleDict 和 nn.Sequential。

PyTorch 几种构建网络的容器类

class MyModel1(nn.Module):
    """Counter-example: Linear layers stored in a plain Python list.

    ``nn.Module.__setattr__`` registers an attribute as a submodule only when
    the assigned value is itself an ``nn.Module``. A ``list`` is stored as an
    ordinary attribute, so the three Linear layers below are NOT submodules:
    they are missing from ``print(net)`` and from ``net.parameters()`` (an
    optimizer built from ``net.parameters()`` would never update them), and
    ``net.to(device)`` does not move them. Use ``nn.ModuleList`` (see
    MyModel2) to register them properly.

    With a 1x28x28 input, conv1 -> 24x24, pool -> 12x12, conv2 -> 8x8, so the
    flattened feature size is 16 * 8 * 8 = 1024, matching the first Linear.
    """

    def __init__(self):
        # Zero-argument super() avoids naming the class explicitly; the
        # original super(MyModel, self) referred to a class that does not exist.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        # Deliberately a plain list to demonstrate the pitfall described above:
        # these layers are NOT registered as submodules.
        self.linears = [nn.Linear(1024, 512), nn.Linear(512, 128), nn.Linear(128, 10)]

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        # Indexing the plain list still works at call time; the problem is only
        # that nn.Module does not know about these layers.
        x = F.relu(self.linears[0](x))
        x = F.relu(self.linears[1](x))
        x = self.linears[2](x)
        return x

在继承 nn.Module 来构造自定义模块时,父类的 __setattr__ 会检查被赋值的成员属性是否为 nn.Module 的实例;如果是,就将其注册为该模块的子模块。但是,如果子模块没有直接赋值给对象属性,而是先用 Python 的 list 封装再赋值,那么这些子模块就不会被注册:它们不会出现在 print(net) 的输出中,其参数也不会出现在 net.parameters() 里,因此也不会被优化器更新。

>>> net = MyModel1()
>>> net
MyModel1(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

1. nn.ModuleList

为了解决上述问题,PyTorch 提供了 nn.ModuleList。nn.ModuleList 同样继承自 nn.Module 类,并实现了与 Python 内置 list 类似的功能(下标访问、迭代、append、extend、insert 等),放入其中的模块都会被注册为子模块。

class ModuleList(nn.Module):
    r"""Simplified sketch of ``nn.ModuleList``.

    A list-like container whose elements are registered as submodules under
    the positional names "0", "1", "2", ... Because every element goes through
    ``add_module``, its parameters appear in ``parameters()`` and it is listed
    by ``print()``.
    """

    def __init__(self, modules=None):
        r"""
        Parameters:
            modules: iterable of nn.Module, optional. Modules to add initially.
        """
        super().__init__()
        if modules is not None:
            self.extend(modules)

    def append(self, module):
        r"""
        Append ``module`` to the end of the list.

        Parameters:
            module: nn.Module, the module to append.
        Returns:
            self, so calls can be chained (as ``nn.ModuleList.append`` does).
        """
        self.add_module(str(len(self._modules)), module)
        return self

    def extend(self, modules):
        r"""
        Append every module from the iterable ``modules`` to the end of the list.

        Parameters:
            modules: iterable of nn.Module, the modules to append.
        Returns:
            self, so calls can be chained.
        """
        for module in modules:
            self.append(module)
        return self

    def insert(self, index, module):
        r"""
        Insert ``module`` before position ``index``.

        Negative indices count from the end, and out-of-range indices are
        clamped, as with ``list.insert``.

        Parameters:
            index: int, the position to insert at.
            module: nn.Module, the module to insert.
        """
        size = len(self._modules)
        if index < 0:
            index += size
        index = max(0, min(index, size))
        # Shift the later entries one slot to the right so the keys stay a
        # contiguous "0".."n" sequence, then fill the freed slot.
        for i in range(size, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self.add_module(str(index), module)

nn.ModuleList 可以被识别添加为定义模块的子模块,extend、append 和 insert 方法和 list 中的用法类似。

class MyModel2(nn.Module):
    """Conv feature extractor followed by three Linear layers in an nn.ModuleList.

    Because ``nn.ModuleList`` is itself an ``nn.Module``, assigning it to
    ``self.linears`` registers the container, and every layer inside it becomes
    a submodule whose parameters are visible to ``net.parameters()``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        # Start with one layer, then grow the container the way a Python list
        # grows: extend() takes an iterable, append() takes a single module.
        self.linears = nn.ModuleList([nn.Linear(1024, 512)])
        self.linears.extend([nn.Linear(512, 128)])
        self.linears.append(nn.Linear(128, 10))

    def forward(self, x):
        features = self.pool(F.relu(self.conv1(x)))
        features = self.flatten(F.relu(self.conv2(features)))
        # ReLU after every Linear except the last one, which produces the logits.
        *hidden_layers, output_layer = self.linears
        for layer in hidden_layers:
            features = F.relu(layer(features))
        return output_layer(features)
    
net = MyModel2()
# A bare `net` expression only echoes in an interactive session; print it
# explicitly so the module tree is also shown when this runs as a script.
print(net)

输出:

MyModel2(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linears): ModuleList(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=128, bias=True)
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
)

nn.ParameterList 使用和 nn.ModuleList 类似,可以将 nn.Parameter 添加为模块的参数。

2. nn.ModuleDict

class ModuleDict(nn.Module):
    r"""Simplified sketch of ``nn.ModuleDict``.

    A dict-like container whose values are registered as submodules under
    their string keys (via ``add_module``), preserving insertion order.
    """

    def __init__(self, modules=None):
        r"""
        Parameters:
            modules: optional. Either a mapping of string -> module, or an
                iterable of (string, module) pairs.
        """
        super().__init__()
        if modules is not None:
            self.update(modules)

    def clear(self):
        r"""Remove all items from the ModuleDict."""
        self._modules.clear()

    def items(self):
        r"""Return an iterable view of the (key, module) pairs."""
        return self._modules.items()

    def keys(self):
        r"""Return an iterable view of the keys."""
        return self._modules.keys()

    def values(self):
        r"""Return an iterable view of the modules."""
        return self._modules.values()

    def pop(self, key):
        r"""Remove ``key`` from the ModuleDict and return its module.

        Raises:
            KeyError: if ``key`` is not present.
        """
        return self._modules.pop(key)

    def update(self, modules):
        r"""
        Add or overwrite entries.

        Parameters:
            modules: a mapping of string -> module (anything with ``.items()``,
                e.g. dict, OrderedDict, ModuleDict), or an iterable of
                (string, module) pairs.
        """
        # Mappings expose .items(); anything else is treated as key/module pairs.
        pairs = modules.items() if hasattr(modules, "items") else modules
        for key, module in pairs:
            self.add_module(key, module)

定义自定义网络:

class MyModel3(nn.Module):
    """Conv feature extractor followed by three named Linear layers in an nn.ModuleDict.

    The ModuleDict registers each layer under its string key, so the layers are
    submodules and are looked up by name in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        # Build from a dict, then add one more entry with update(), like a dict.
        self.linears = nn.ModuleDict({'linear1': nn.Linear(1024, 512),
                                      'linear2': nn.Linear(512, 128)})
        self.linears.update({'linear3': nn.Linear(128, 10)})

    def forward(self, x):
        features = self.pool(F.relu(self.conv1(x)))
        features = self.flatten(F.relu(self.conv2(features)))
        # Hidden layers get a ReLU; the final layer returns raw logits.
        for key in ('linear1', 'linear2'):
            features = F.relu(self.linears[key](features))
        return self.linears['linear3'](features)
    
# Instantiate the model and print its module tree; the ModuleDict and the
# three Linear layers it holds appear under "(linears)".
net = MyModel3()
print(net)

输出:

MyModel3(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linears): ModuleDict(
    (linear1): Linear(in_features=1024, out_features=512, bias=True)
    (linear2): Linear(in_features=512, out_features=128, bias=True)
    (linear3): Linear(in_features=128, out_features=10, bias=True)
  )
)

3. nn.Sequential

nn.Sequential 同样继承了 nn.Module 类。与普通 Module 不同的是,Sequential 已经默认定义了 forward 函数,会按照顺序依次把上一层的输出作为下一层的输入。注意:Sequential 中的子模块只能通过整数下标(如 self.linears[0])访问,不能用字符串名字作为下标。
例1:

class MyModel4(nn.Module):
    """Conv feature extractor followed by three Linear layers in an nn.Sequential.

    nn.Sequential registers its children under the positional names "0", "1",
    "2" and is indexed by integer position. Indexing it with strings such as
    'linear1' raises TypeError, because Sequential.__getitem__ accepts only
    integers or slices. A ReLU is applied between the Linear layers, so the
    layers are called one by one. Calling ``self.linears(x)`` directly would
    chain the three Linear layers with no activation in between.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        self.linears = nn.Sequential(nn.Linear(1024, 512),
                                     nn.Linear(512, 128),
                                     nn.Linear(128, 10))

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        # Positional indexing: the children are named "0", "1", "2".
        x = F.relu(self.linears[0](x))
        x = F.relu(self.linears[1](x))
        x = self.linears[2](x)
        return x
    
# Instantiate the model and print its module tree; the Sequential's layers
# appear under the positional names (0), (1), (2).
net = MyModel4()
print(net)

输出:

MyModel4(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linears): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=128, bias=True)
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
)

例2:

class MyModel5(nn.Module):
    """Like MyModel4, but the nn.Sequential is built from an OrderedDict.

    The OrderedDict keys ('1', '2', '3') only name the children; they are what
    ``print(net)`` shows and are usable as attribute names via ``getattr``.
    nn.Sequential itself is still indexed by integer position, so
    ``self.linears[0]`` is the layer registered as '1'. A ReLU is applied
    between the Linear layers, so the layers are called one by one rather than
    through ``self.linears(x)``.
    """

    def __init__(self):
        # Local import keeps this example self-contained.
        from collections import OrderedDict

        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        self.linears = nn.Sequential(OrderedDict(
                                        [('1', nn.Linear(1024, 512)),
                                         ('2', nn.Linear(512, 128)),
                                         ('3', nn.Linear(128, 10))])
                                     )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        # Positional indexing: Sequential.__getitem__ does not accept string
        # keys, and no 'linear1'/'linear2'/'linear3' keys exist here anyway.
        x = F.relu(self.linears[0](x))
        x = F.relu(self.linears[1](x))
        x = self.linears[2](x)
        return x
    
# Instantiate the model and print its module tree; the Sequential's layers
# appear under the OrderedDict keys (1), (2), (3).
net = MyModel5()
print(net)

输出:

MyModel5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linears): Sequential(
    (1): Linear(in_features=1024, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=128, bias=True)
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)

你可能感兴趣的:(PyTorch使用,深度学习,机器学习,数据分析,pytorch)