PyTorch - torch.nn.Sequential
flyfish
官网的示例
# The snippet needs `nn` and `OrderedDict` in scope to run on its own;
# without these imports it fails with NameError.
from collections import OrderedDict

import torch.nn as nn

# Example of using Sequential: children are registered under the keys "0", "1", ...
model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU(),
)

# Example of using Sequential with OrderedDict: each child gets the given name,
# and the forward pass follows the insertion order of the dict.
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 20, 5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20, 64, 5)),
    ('relu2', nn.ReLU()),
]))
运行起来
示例代码1
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Two Conv2d + ReLU stages held in an unnamed ``nn.Sequential``.

    Because the layers are passed positionally, ``nn.Sequential`` registers
    them under the keys "0".."3", which is why ``print`` shows ``(0)``, ``(1)``, ...
    """

    def __init__(self):
        super().__init__()
        # Conv2d(in_channels, out_channels, kernel_size): channels go 1 -> 20 -> 64
        # with 5x5 kernels (default stride 1, no padding).
        self.model = nn.Sequential(
            nn.Conv2d(1, 20, 5),
            nn.ReLU(),
            nn.Conv2d(20, 64, 5),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the layers in order and return the result.

        ``nn.Module`` has no usable default ``forward``; without this method,
        calling ``Net()(x)`` raises NotImplementedError.

        Args:
            x: an input batch shaped (N, 1, H, W). Each unpadded 5x5 conv trims
               4 pixels from H and W, so the result is (N, 64, H - 8, W - 8).
        """
        return self.model(x)


# NOTE: despite its name, ``output`` holds the model instance, not a tensor.
output = Net()
print(output)
# =============================================================================
# Net(
# (model): Sequential(
# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
# (1): ReLU()
# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
# (3): ReLU()
# )
# )
# =============================================================================
示例代码2
# 每一层都有名字
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class Net(nn.Module):
    """Two Conv2d + ReLU stages in an ``nn.Sequential`` with named layers.

    Passing an ``OrderedDict`` gives every child an explicit name
    ("conv1", "relu1", ...). Those names appear in ``print`` output and in
    ``state_dict`` keys. The forward order follows the dict's insertion order.
    """

    def __init__(self):
        super().__init__()
        # Conv2d(in_channels, out_channels, kernel_size): channels go 1 -> 20 -> 64
        # with 5x5 kernels (default stride 1, no padding).
        self.model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 20, 5)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(20, 64, 5)),
            ('relu2', nn.ReLU()),
        ]))

    def forward(self, x):
        """Apply the named layers in order and return the result.

        ``nn.Module`` has no usable default ``forward``; without this method,
        calling ``Net()(x)`` raises NotImplementedError.

        Args:
            x: an input batch shaped (N, 1, H, W). Each unpadded 5x5 conv trims
               4 pixels from H and W, so the result is (N, 64, H - 8, W - 8).
        """
        return self.model(x)


# NOTE: despite its name, ``output`` holds the model instance, not a tensor.
output = Net()
print(output)
# =============================================================================
# Net(
# (model): Sequential(
# (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
# (relu1): ReLU()
# (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
# (relu2): ReLU()
# )
# )
# =============================================================================
示例代码3
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class Net(nn.Module):
    """Two Conv2d + ReLU stages added one by one to an empty ``nn.Sequential``.

    ``add_module(name, module)`` registers each child under ``name``. The
    children run in the order they were added, so the result is equivalent to
    building the Sequential from an ``OrderedDict`` with the same names.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential()
        # Conv2d(in_channels, out_channels, kernel_size): channels go 1 -> 20 -> 64
        # with 5x5 kernels (default stride 1, no padding).
        self.model.add_module('conv1', nn.Conv2d(1, 20, 5))
        self.model.add_module('relu1', nn.ReLU())
        self.model.add_module('conv2', nn.Conv2d(20, 64, 5))
        self.model.add_module('relu2', nn.ReLU())

    def forward(self, x):
        """Apply the added layers in insertion order and return the result.

        ``nn.Module`` has no usable default ``forward``; without this method,
        calling ``Net()(x)`` raises NotImplementedError.

        Args:
            x: an input batch shaped (N, 1, H, W). Each unpadded 5x5 conv trims
               4 pixels from H and W, so the result is (N, 64, H - 8, W - 8).
        """
        return self.model(x)


# NOTE: despite its name, ``output`` holds the model instance, not a tensor.
output = Net()
print(output)
# =============================================================================
# Net(
# (model): Sequential(
# (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
# (relu1): ReLU()
# (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
# (relu2): ReLU()
# )
# )
#
# =============================================================================