import torch
import torch.nn as nn
from collections import OrderedDict
class MLP(nn.Module):
    """Two-layer perceptron: Linear(784, 256) -> ReLU -> Linear(256, 10)."""

    def __init__(self, **kwargs):
        # Extra keyword arguments are forwarded unchanged to nn.Module.__init__.
        super().__init__(**kwargs)
        self.hidden = nn.Linear(784, 256)
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        """Return the output-layer result for input ``x``.

        ``x`` must have 784 features in its last dimension (the
        ``in_features`` of ``self.hidden``); the result has 10.
        """
        a = self.act(self.hidden(x))
        # Apply the output layer to the hidden activations. Returning
        # ``self.output`` itself would hand back the nn.Linear module, not a tensor.
        return self.output(a)


X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)
输出
MLP(
(hidden): Linear(in_features=784, out_features=256, bias=True)
(act): ReLU()
(output): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[-0.4208, -0.1205, -0.1616, -0.0421, -0.0989, -0.1204, -0.0447, 0.0534,
0.3461, 0.2020],
[-0.1718, -0.0578, -0.0979, 0.1144, -0.1164, -0.1234, -0.0635, 0.0248,
0.1519, 0.1203]], grad_fn=<AddmmBackward>)
# The same Linear(784, 256) -> ReLU -> Linear(256, 10) stack, built with
# nn.Sequential; layers passed positionally are registered as "0", "1", "2".
net = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
print(net)
# Sequential supports Python-style indexing, so -1 selects the last layer.
print(net[-1])
Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)
Linear(in_features=256, out_features=10, bias=True)
通过 OrderedDict 创建指定索引(层名称)的模型:
# Passing an OrderedDict to nn.Sequential gives each layer a name, so layers
# can be reached as attributes (net.linear2) as well as by position.
# A list of (name, module) pairs makes the layer order explicit instead of
# relying on the insertion order of a dict literal.
net = nn.Sequential(
    OrderedDict(
        [
            ("linear1", nn.Linear(784, 256)),
            ("ReLU1", nn.ReLU()),
            ("linear2", nn.Linear(256, 10)),
        ]
    )
)
print(net)
print(net.linear2)
Sequential(
  (linear1): Linear(in_features=784, out_features=256, bias=True)
  (ReLU1): ReLU()
  (linear2): Linear(in_features=256, out_features=10, bias=True)
)
Linear(in_features=256, out_features=10, bias=True)
还可以嵌套构建:
# Sequentials can be nested: the outer container indexes its children by
# position (net[0], net[1]); each inner one names its layers via OrderedDict.
# Layers are given as (name, module) pairs so their order is explicit.
net = nn.Sequential(
    nn.Sequential(
        OrderedDict(
            [
                ("linear1", nn.Linear(784, 256)),
                ("ReLU1", nn.ReLU()),
                ("linear2", nn.Linear(256, 10)),
            ]
        )
    ),
    nn.Sequential(
        OrderedDict(
            [
                ("linear1", nn.Linear(10, 10)),
                ("ReLU1", nn.ReLU()),
                ("linear2", nn.Linear(10, 10)),
            ]
        )
    ),
)
print(net)
print(net[0].linear2)
Sequential(
  (0): Sequential(
    (linear1): Linear(in_features=784, out_features=256, bias=True)
    (ReLU1): ReLU()
    (linear2): Linear(in_features=256, out_features=10, bias=True)
  )
  (1): Sequential(
    (linear1): Linear(in_features=10, out_features=10, bias=True)
    (ReLU1): ReLU()
    (linear2): Linear(in_features=10, out_features=10, bias=True)
  )
)
Linear(in_features=256, out_features=10, bias=True)
# nn.ModuleList stores submodules in a list-like container and supports the
# usual list mutations: extend/+=, append and insert.
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net2 = nn.ModuleList([nn.Linear(256, 10)])
net += net2  # same as net.extend(net2) -> [Linear, ReLU, Linear]
net.append(nn.Linear(10, 10))  # -> [Linear, ReLU, Linear, Linear]
net.insert(3, nn.ReLU())  # -> [Linear, ReLU, Linear, ReLU, Linear]
print(net)
ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
  (3): ReLU()
  (4): Linear(in_features=10, out_features=10, bias=True)
)
ModuleList主要方法:append()、extend()、insert()(用法见上例)。
ModuleDict 示例(在 forward 中按名称选择子模块):
class ModuleDict(nn.Module):
    """Toy network whose layer and activation are picked by name in forward().

    Both candidate sets live in ``nn.ModuleDict`` containers, so the
    submodules are registered on this module and can be looked up by key.
    """

    def __init__(self):
        super().__init__()
        # Stage selected by the ``choice`` argument of forward().
        self.choices = nn.ModuleDict(
            {
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(2, 2),
            }
        )
        # Non-linearity selected by the ``act`` argument of forward().
        self.activations = nn.ModuleDict(
            {
                'relu': nn.ReLU(),
                'prelu': nn.PReLU(),
            }
        )

    def forward(self, x, choice, act):
        """Apply ``self.choices[choice]``, then ``self.activations[act]``, to ``x``.

        An unknown name raises ``KeyError`` from the ModuleDict lookup.
        """
        stage = self.choices[choice]
        hidden = stage(x)
        nonlinearity = self.activations[act]
        return nonlinearity(hidden)


net = ModuleDict()
print(net)
# Random input in (N, C, H, W) layout: 4 samples, 10 channels, 32x32.
fake_img = torch.randn(4, 10, 32, 32)
# 3x3 convolution without padding: 32x32 -> 30x30.
output = net(fake_img, 'conv', 'relu')
print(output.size())
# 2x2 max-pooling with stride 2: 32x32 -> 16x16.
output = net(fake_img, 'pool', 'relu')
print(output.size())
ModuleDict(
  (choices): ModuleDict(
    (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (activations): ModuleDict(
    (prelu): PReLU(num_parameters=1)
    (relu): ReLU()
  )
)
torch.Size([4, 10, 30, 30])
torch.Size([4, 10, 16, 16])