pytorch构建模型整体上看主要有两种方式,一种是继承torch.nn.Module
类,另一种就是直接使用继承自该类的子类:Sequential
,ModuleList
,ModuleDict
。
torch.nn.Module
类Module
类是nn
模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想要的模型。MLP
类重载了Module
类的__init__
函数和forward
函数。它们分别用于创建模型参数和定义前向计算。前向计算也即正向传播。import torch
from torch import nn
class MLP(nn.Module):
# 声明带有模型参数的层,这里声明了两个全连接层
def __init__(self, **kwargs):
# 调用MLP父类Module的构造函数来进行必要的初始化。
# 这样在构造实例时还可以指定其他函数参数
super(MLP, self).__init__(**kwargs)
self.hidden = nn.Linear(784, 256) # 隐藏层
self.act = nn.ReLU()
self.output = nn.Linear(256, 10) # 输出层
# 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP()
torch.nn.Module
类的子类:Sequential
,ModuleList
,ModuleDict
1. Sequential
子类
net = nn.Sequential(
nn.Linear(num_inputs, num_hiddens),
nn.ReLU(),
)
net.add_module("linear2",nn.Linear(num_hiddens,num_outs))
net.add_module("relu2",nn.ReLU())
或者也可以这样:
from collections import OrderedDict
net = nn.Sequential(
OrderedDict(
[
("linear1", nn.Linear(num_inputs, num_hiddens)),
("relu1", nn.ReLU()),
]
))
net.add_module("linear2",nn.Linear(num_hiddens,num_outs))
net.add_module("relu2",nn.ReLU())
2. ModuleList
子类
有以下属性: append(module)
, extend(module)
,insert(index, module)
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
需要注意的是:ModuleList
仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现forward
功能需要自己实现,所以上面执行net(torch.zeros(1, 784))
会报NotImplementedError
;而Sequential
内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部forward
功能已经实现。
ModuleList
的出现只是让网络定义前向传播时更加灵活,见下面官网的例子。
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
3. ModuleDict
子类
有以下属性:clear()
,items()
,keys()
,pop(key)
,update(modules)
net = nn.ModuleDict({
'linear': nn.Linear(784, 256),
'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
和ModuleList
一样,ModuleDict
实例仅仅是存放了一些模块的字典,并没有定义forward
函数需要自己定义。
torch.nn.Module
和torch.nn.Sequential
import torch
import torch.nn.functional as F
from collections import OrderedDict
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = torch.nn.Sequential(
OrderedDict(
[
("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
("relu1", torch.nn.ReLU()),
("pool", torch.nn.MaxPool2d(2))
]
))
self.dense = torch.nn.Sequential()
self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
self.dense.add_module("relu2",torch.nn.ReLU())
self.dense.add_module("dense2",torch.nn.Linear(128, 10))
)
def forward(self, x):
conv_out = self.conv1(x)
res = conv_out.view(conv_out.size(0), -1)
out = self.dense(res)
return out
通过这个例子只想告诉大家,其实创建模型是很灵活的。