Pytorch构建模型的五种手段

最近看pytorch源码,发现模型的命名非常灵活,故整理部分用法用作备忘。pytorch源码地址

Pytorch构建模型的五种手段

使用pytorch可以方便地进行模型搭建,如果只是简单的分类任务,可以直接调用torchvision.models使用pytorch提供的模型。若要自己搭建网络,主要有以下五种方式:

  • 方式一:最快捷 nn.Sequential

import torch
import torch.nn as nn

# Modules passed positionally are registered under their index: "0", "1", "2".
net1 = nn.Sequential(*[
    nn.Linear(784, 100),  # 784 inputs -> 100 hidden units
    nn.ReLU(),
    nn.Linear(100, 10),  # 100 hidden units -> 10 outputs
])
print(net1)

这种方式适用于将已有模块按顺序简单组合,各子模块的名称(name)默认为从 0 开始的数字索引。

Sequential(
  (0): Linear(in_features=784, out_features=100, bias=True)
  (1): ReLU()
  (2): Linear(in_features=100, out_features=10, bias=True)
)
  • 方式二:快捷的同时更好地命名 nn.Sequential + OrderedDict

import torch
import torch.nn as nn
from collections import OrderedDict

# Each keyword becomes the registered name of the corresponding sub-module;
# keyword order is preserved, so the layers run as fc1 -> relu -> fc2.
net2 = nn.Sequential(OrderedDict(
    fc1=nn.Linear(784, 100),
    relu=nn.ReLU(),
    fc2=nn.Linear(100, 10),
))
print(net2)

效果同上,但各子模块的名称(name)由自己指定。OrderedDict 是一个保持插入顺序的字典。

Sequential(
  (fc1): Linear(in_features=784, out_features=100, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=100, out_features=10, bias=True)
)
  • 方式三: 最常用 Class

import torch
import torch.nn as nn
from collections import OrderedDict

class Net(nn.Module): #
    """Small fully connected network: 784 -> 100 -> 10 -> 10.

    Sub-modules assigned as attributes are registered under the attribute
    name, so printing the model lists them as ``block1``, ``relu`` and ``fc``.
    """

    def __init__(self): #
        super().__init__() #
        stem = OrderedDict(
            fc1=nn.Linear(784, 100),
            relu=nn.ReLU(),
            fc2=nn.Linear(100, 10),
        )
        self.block1 = nn.Sequential(stem)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(10, 10)

    def forward(self, X): #
        """Apply ``block1``, then ``relu``, then ``fc`` to ``X``."""
        hidden = self.block1(X)
        activated = self.relu(hidden)
        X = self.fc(activated)
        return X #


net3 = Net()
print(net3)

最常用的模型定义的方法,加 # 的五行是定义新模型的套路,可以在__init__函数中直接使用方式一和方式二。

Net(
  (block1): Sequential(
    (fc1): Linear(in_features=784, out_features=100, bias=True)
    (relu): ReLU()
    (fc2): Linear(in_features=100, out_features=10, bias=True)
  )
  (relu): ReLU()
  (fc): Linear(in_features=10, out_features=10, bias=True)
)
  • 方式四:更方便的命名 Class + add_module

import torch
import torch.nn as nn
from collections import OrderedDict

class Net(nn.Module):
    """Fully connected network whose ``features`` stage is grown in a loop.

    Shows ``add_module`` both on the model itself (``relu``) and on a nested
    ``nn.Sequential`` (``features``), where it appends named sub-blocks.
    """

    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(OrderedDict(
            fc1=nn.Linear(784, 100),
            relu=nn.ReLU(),
            fc2=nn.Linear(100, 10),
        ))
        # add_module(name, module) registers the module under the given name.
        self.add_module('relu', nn.ReLU())
        self.features = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        # Append identically shaped Linear+ReLU pairs named basic_fc1..basic_fc3
        # after the two positional entries ("0" and "1").
        num_fc = 3
        for idx in range(1, num_fc + 1):
            pair = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
            self.features.add_module('basic_fc%d' % idx, pair)
        self.classifier = nn.Linear(10, 10)

    def forward(self, X):
        """Apply ``block1``, ``relu``, ``features`` and ``classifier`` in order."""
        X = self.block1(X)
        X = self.relu(X)
        X = self.features(X)
        return self.classifier(X)


net4 = Net()
print(net4)

这种方式的主要优点是可以在 for 循环中重复添加现有模块并批量命名,从而大大减少代码量。

Net(
  (block1): Sequential(
    (fc1): Linear(in_features=784, out_features=100, bias=True)
    (relu): ReLU()
    (fc2): Linear(in_features=100, out_features=10, bias=True)
  )
  (relu): ReLU()
  (features): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU()
    (basic_fc1): Sequential(
      (0): Linear(in_features=10, out_features=10, bias=True)
      (1): ReLU()
    )
    (basic_fc2): Sequential(
      (0): Linear(in_features=10, out_features=10, bias=True)
      (1): ReLU()
    )
    (basic_fc3): Sequential(
      (0): Linear(in_features=10, out_features=10, bias=True)
      (1): ReLU()
    )
  )
  (classifier): Linear(in_features=10, out_features=10, bias=True)
)
  • 方式五: 为重复性模块作准备 def

import torch
import torch.nn as nn

# Each entry is (num_repeat, in_feature, out_feature) for one block.
arch = ((3, 784, 1000), (2, 1000, 10))


def block(num_repeat, in_feature, out_feature):
    """Build a stack of ``num_repeat`` Linear+ReLU stages.

    The first stage maps ``in_feature`` to ``out_feature``; every later stage
    maps ``out_feature`` to ``out_feature``.

    Args:
        num_repeat: Number of Linear+ReLU stages to stack.
        in_feature: Input width of the first Linear layer.
        out_feature: Output width of every Linear layer.

    Returns:
        An ``nn.Sequential`` with ``2 * num_repeat`` children (empty when
        ``num_repeat`` is 0). Each Linear layer is wrapped in its own
        single-element ``nn.Sequential``.
    """
    blk = []
    for i in range(num_repeat):
        # Only the first stage sees the external input width.
        width_in = in_feature if i == 0 else out_feature
        blk.append(nn.Sequential(nn.Linear(width_in, out_feature)))
        blk.append(nn.ReLU())
    # Return only after all stages are appended; returning from inside the
    # loop would build a single stage no matter what num_repeat is.
    return nn.Sequential(*blk)
        
def net(arch, fc_feature, fc_hidden=4096):
    """Assemble a sequential model from block specs plus a two-layer head.

    Args:
        arch: Iterable of ``(num_repeat, in_feature, out_feature)`` tuples;
            entry ``i`` is passed to ``block`` and registered as
            ``block_<i+1>``.
        fc_feature: Input width of the ``fc1`` layer. ``fc1`` runs right
            after the last block, so this should match that block's output
            width.
        fc_hidden: Width of the hidden layer between ``fc1`` and ``fc2``.

    Returns:
        An ``nn.Sequential`` holding the blocks followed by ``fc1``, ``relu``
        and ``fc2`` (10 outputs).
    """
    model = nn.Sequential()
    for idx, spec in enumerate(arch, start=1):
        num_repeat, in_feature, out_feature = spec
        model.add_module('block_%d' % idx, block(num_repeat, in_feature, out_feature))
    head = (
        ('fc1', nn.Linear(fc_feature, fc_hidden)),
        ('relu', nn.ReLU()),
        ('fc2', nn.Linear(fc_hidden, 10)),
    )
    for name, layer in head:
        model.add_module(name, layer)
    return model

# fc1 consumes the output of the last block, so its input width must equal
# that block's out_feature (arch[-1][2]); a hard-coded 784 would make any
# forward pass fail with a shape mismatch.
net5 = net(arch, arch[-1][2], 4096)
print(net5)

使用def的方式可以很方便、简单地定义一些在模型构建中需要的基础模块。

Sequential(
  (block_1): Sequential(
    (0): Sequential(
      (0): Linear(in_features=784, out_features=1000, bias=True)
    )
    (1): ReLU()
  )
  (block_2): Sequential(
    (0): Sequential(
      (0): Linear(in_features=1000, out_features=10, bias=True)
    )
    (1): ReLU()
  )
  (fc1): Linear(in_features=784, out_features=4096, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=4096, out_features=10, bias=True)
)

你可能感兴趣的:(pytorch)