[PyTorch] Model Construction: Module and Its Subclasses Sequential, ModuleList, and ModuleDict

0. Import the required libraries

import torch
import torch.nn as nn
from collections import OrderedDict

1. Building and accessing a model by subclassing Module

The Module class, provided by the torch.nn package, is a model-construction class and the base class of all neural network modules; we can subclass it to define the model we want.
We only need to override the forward method: the backward computation is generated automatically by autograd.
 
class MLP(nn.Module):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256)  # hidden layer
        self.act = nn.ReLU()               # activation
        self.output = nn.Linear(256, 10)   # output layer

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)


Inspect the model:
X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)

Output:


MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[-0.4208, -0.1205, -0.1616, -0.0421, -0.0989, -0.1204, -0.0447,  0.0534,
          0.3461,  0.2020],
        [-0.1718, -0.0578, -0.0979,  0.1144, -0.1164, -0.1234, -0.0635,  0.0248,
          0.1519,  0.1203]], grad_fn=<AddmmBackward>)
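
To also show the access side of this section, here is a minimal sketch (not part of the output above) of reaching the submodules and parameters of the MLP; the attribute names hidden, act and output come from the class definition:

print(net.hidden)   # Linear(in_features=784, out_features=256, bias=True)
print(net.act)      # ReLU()

for name, param in net.named_parameters():
    print(name, param.shape)
# hidden.weight torch.Size([256, 784])
# hidden.bias torch.Size([256])
# output.weight torch.Size([10, 256])
# output.bias torch.Size([10])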


2. Building and accessing a model with Sequential

net = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256,10)
)
print(net)
print(net[-1])

Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)
Linear(in_features=256, out_features=10, bias=True)

We can also build a model with named layers by passing an OrderedDict, so each layer can be accessed by its name:

net = nn.Sequential(
        OrderedDict({
            "linear1":nn.Linear(784,256),
            "ReLU1": nn.ReLU(),
            "linear2": nn.Linear(256,10)
        })
    )
print(net)
print(net.linear2)

Sequential(
  (linear1): Linear(in_features=784, out_features=256, bias=True)
  (ReLU1): ReLU()
  (linear2): Linear(in_features=256, out_features=10, bias=True)
)
Linear(in_features=256, out_features=10, bias=True)
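
With named layers, integer indexing still works alongside attribute access; as a small sketch (not in the original), net[2] refers to the same layer as net.linear2:

print(net[2])
# Linear(in_features=256, out_features=10, bias=True)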

Sequential blocks can also be nested:

net = nn.Sequential(
        nn.Sequential(OrderedDict({
            "linear1":nn.Linear(784,256),
            "ReLU1": nn.ReLU(),
            "linear2": nn.Linear(256,10)
        })),
         nn.Sequential(OrderedDict({
            "linear1":nn.Linear(10,10),
            "ReLU1": nn.ReLU(),
            "linear2": nn.Linear(10,10)
        }))
    )
print(net)
print(net[0].linear2)

Sequential(
  (0): Sequential(
    (linear1): Linear(in_features=784, out_features=256, bias=True)
    (ReLU1): ReLU()
    (linear2): Linear(in_features=256, out_features=10, bias=True)
  )
  (1): Sequential(
    (linear1): Linear(in_features=10, out_features=10, bias=True)
    (ReLU1): ReLU()
    (linear2): Linear(in_features=10, out_features=10, bias=True)
  )
)
Linear(in_features=256, out_features=10, bias=True)
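
As a quick sanity check (a sketch assuming the same 784-dimensional input as before), the nested model runs end to end: the first sub-Sequential outputs 10 features, which the second one consumes:

X = torch.rand(2, 784)
Y = net(X)          # flows through both nested Sequential blocks in order
print(Y.shape)      # torch.Size([2, 10])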

3. Building and accessing a model with ModuleList

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net2 = nn.ModuleList([nn.Linear(256, 10)])
net.extend(net2)
net.append(nn.Linear(10, 10))
net.insert(3,nn.ReLU())

print(net)

ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
  (3): ReLU()
  (4): Linear(in_features=10, out_features=10, bias=True)
)

Main ModuleList methods:

  • append(): appends a layer to the end of the ModuleList
  • extend(): concatenates another ModuleList onto the end
  • insert(): inserts a layer at the specified position in the ModuleList
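
Unlike Sequential, a ModuleList does not define a forward method, so the layers are not chained automatically; it only registers them as submodules so their parameters are tracked. Below is a minimal sketch (the MLPList class is a hypothetical example, not from the original) of using a ModuleList inside a custom Module:

class MLPList(nn.Module):
    def __init__(self):
        super(MLPList, self).__init__()
        # registering the layers in a ModuleList makes their parameters visible to the model
        self.layers = nn.ModuleList([
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        ])

    def forward(self, x):
        # ModuleList has no forward of its own, so we chain the layers manually
        for layer in self.layers:
            x = layer(x)
        return x

net = MLPList()
print(net(torch.rand(2, 784)).shape)  # torch.Size([2, 10])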

4. Building and accessing a model with ModuleDict

class ModuleDict(nn.Module):
    def __init__(self):
        super(ModuleDict, self).__init__()
        # candidate layers, selected by key at forward time
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(2, 2)
        })
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice, act):
        # pick the layer and the activation by their dictionary keys
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

net = ModuleDict()
print(net)
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu')
print(output.shape)

output = net(fake_img, 'pool', 'relu')
print(output.shape)

ModuleDict(
  (choices): ModuleDict(
    (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (activations): ModuleDict(
    (prelu): PReLU(num_parameters=1)
    (relu): ReLU()
  )
)
torch.Size([4, 10, 30, 30])
torch.Size([4, 10, 16, 16])
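
nn.ModuleDict also supports the usual dict-style operations (keys(), values(), items(), assignment by key, etc.), so the set of choices can be extended after construction. A small sketch reusing the net built above; the avgpool entry is a hypothetical addition, not in the original:

print(list(net.choices.keys()))   # ['conv', 'pool']

# add a new choice after construction; it is registered as a submodule
net.choices['avgpool'] = nn.AvgPool2d(2, 2)
output = net(fake_img, 'avgpool', 'prelu')
print(output.shape)               # torch.Size([4, 10, 16, 16])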
