PyTorch - torch.nn.ModuleList
flyfish
from collections import namedtuple
#from collections import list
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModule(nn.Module):
    """Demo module that keeps a stack of ``nn.Linear`` layers in an ``nn.ModuleList``.

    ``nn.ModuleList`` registers every contained layer as a submodule, so the
    layers show up in ``print(module)`` and in ``parameters()``; a plain Python
    list would hide them from PyTorch.

    Args:
        num_layers: number of Linear layers to create (default 10).
        features: in/out feature size of every Linear layer (default 10).
    """

    def __init__(self, num_layers: int = 10, features: int = 10):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(features, features) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Run ``x`` through the layer stack, printing each index and layer.

        Args:
            x: input tensor whose last dimension equals ``features``.

        Returns:
            A tensor of the same shape as ``x``.
        """
        # ModuleList can act as an iterable, or be indexed using ints.
        for i, layer in enumerate(self.linears):
            print("i:", i)
            print("l:", layer)
            # `//` is floor division (divide, then round down), so layer i is
            # combined with layer i // 2: 0,0,1,1,2,2,... Because i // 2 <= i,
            # the index is always valid.
            x = self.linears[i // 2](x) + layer(x)
        return x
# Build the demo module, show its registered layers, and run one forward pass.
net = MyModule()
print(net)
# Name the tensor `x` rather than `input` so the builtin input() is not shadowed.
x = torch.randn(10, 10)
print(net(x))
输出(每个索引对应 ModuleList 中的一个层):
MyModule(
(linears): ModuleList(
(0): Linear(in_features=10, out_features=10, bias=True)
(1): Linear(in_features=10, out_features=10, bias=True)
(2): Linear(in_features=10, out_features=10, bias=True)
(3): Linear(in_features=10, out_features=10, bias=True)
(4): Linear(in_features=10, out_features=10, bias=True)
(5): Linear(in_features=10, out_features=10, bias=True)
(6): Linear(in_features=10, out_features=10, bias=True)
(7): Linear(in_features=10, out_features=10, bias=True)
(8): Linear(in_features=10, out_features=10, bias=True)
(9): Linear(in_features=10, out_features=10, bias=True)
)
)
i: 0
l: Linear(in_features=10, out_features=10, bias=True)
i: 1
l: Linear(in_features=10, out_features=10, bias=True)
i: 2
l: Linear(in_features=10, out_features=10, bias=True)
i: 3
l: Linear(in_features=10, out_features=10, bias=True)
i: 4
l: Linear(in_features=10, out_features=10, bias=True)
i: 5
l: Linear(in_features=10, out_features=10, bias=True)
i: 6
l: Linear(in_features=10, out_features=10, bias=True)
i: 7
l: Linear(in_features=10, out_features=10, bias=True)
i: 8
l: Linear(in_features=10, out_features=10, bias=True)
i: 9
l: Linear(in_features=10, out_features=10, bias=True)
另一个例子:SSD 中用 ModuleList 保存额外添加的层
# Record type for a source-layer entry: layer index `s0`, plus a `name` and a
# second index `s1`.
# NOTE(review): only `s0` is read in this file; the meaning of `name`/`s1`
# (presumably a sub-layer inside layer s0) should be confirmed.
GraphPath = namedtuple("GraphPath", "s0 name s1")

# mobilenetv2: one GraphPath entry followed by a plain layer index.
source_layer_indexes0 = [GraphPath(s0=14, name="conv", s1=3), 19]

# mobilenetv1: plain layer indexes only.
source_layer_indexes1 = [12, 14]

# squeezenet: a single plain layer index.
source_layer_indexes2 = [12]

# vgg: an (index, extra_module) tuple followed by a plain layer index.
source_layer_indexes3 = [(23, torch.nn.BatchNorm2d(512)), 19]
class SSD(nn.Module):
    """Demo of how an SSD-style model stores extra per-source layers in an ``nn.ModuleList``.

    Args:
        source_layer_indexes: entries describing where features are tapped.
            Each entry is one of:

            * a plain ``int`` layer index,
            * a ``GraphPath`` (only its ``s0`` index is read here),
            * an ``(index, module)`` tuple whose ``module`` is an extra layer.
    """

    def __init__(self, source_layer_indexes: List[Union[int, tuple]]):
        super().__init__()
        # Only (index, module) tuples contribute a layer. GraphPath is itself a
        # tuple subclass (namedtuple), so it has to be excluded explicitly.
        # Wrapping the modules in nn.ModuleList registers them as submodules,
        # so they appear in print(self) and in parameters().
        self.source_layer_add_ons = nn.ModuleList(
            [
                t[1]
                for t in source_layer_indexes
                if isinstance(t, tuple) and not isinstance(t, GraphPath)
            ]
        )
        self.source_layer_indexes = source_layer_indexes
        print("init0:", source_layer_indexes)
        print("init1", self.source_layer_add_ons)

    def forward(self, x):
        """Decode each configured entry, printing the pieces, and return ``x`` unchanged."""
        for end_layer_index in self.source_layer_indexes:
            print("all:", end_layer_index)
            # GraphPath must be checked before the generic tuple case because
            # GraphPath is a tuple subclass and would otherwise match `tuple`.
            if isinstance(end_layer_index, GraphPath):
                path = end_layer_index
                print("GraphPath path:", path)
                end_layer_index = end_layer_index.s0
                print("GraphPath end_layer_index:", end_layer_index)
                added_layer = None
            elif isinstance(end_layer_index, tuple):
                print(end_layer_index)
                added_layer = end_layer_index[1]
                print("tuple added_layer", added_layer)
                end_layer_index = end_layer_index[0]
                print("tuple end_layer_index", end_layer_index)
                path = None
            else:
                # Plain int index: no extra layer and no sub-path.
                added_layer = None
                path = None
            # NOTE(review): `path`, `added_layer` and `end_layer_index` are only
            # decoded for illustration; this demo never applies them to `x`.
        return x
def _run_demo(title, layer_indexes):
    """Print `title`, build an SSD from `layer_indexes`, print it, run one forward pass, and return the net."""
    print(title)
    net = SSD(layer_indexes)
    print(net)
    # The random input is created after print(net), matching the original
    # order; `x` avoids shadowing the builtin input().
    x = torch.randn(2, 3)
    print(net(x))
    return net


net0 = _run_demo("mobilenetv2", source_layer_indexes0)
net1 = _run_demo("mobilenetv1", source_layer_indexes1)
net3 = _run_demo("vgg", source_layer_indexes3)
# =============================================================================
# mobilenetv2
# init0: [GraphPath(s0=14, name='conv', s1=3), 19]
# init1 ModuleList()
# SSD(
# (source_layer_add_ons): ModuleList()
# )
# all: GraphPath(s0=14, name='conv', s1=3)
# GraphPath path: GraphPath(s0=14, name='conv', s1=3)
# GraphPath end_layer_index: 14
# all: 19
# tensor([[-1.1113, 0.5133, -0.2450],
# [-0.2633, -0.7399, 0.7204]])
# mobilenetv1
# init0: [12, 14]
# init1 ModuleList()
# SSD(
# (source_layer_add_ons): ModuleList()
# )
# all: 12
# all: 14
# tensor([[ 0.0620, -0.7680, 0.5259],
# [-0.2823, -1.4778, 0.2337]])
# vgg
# init0: [(23, BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 19]
# init1 ModuleList(
# (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# )
# SSD(
# (source_layer_add_ons): ModuleList(
# (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# )
# )
# all: (23, BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
# (23, BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
# tuple added_layer BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# tuple end_layer_index 23
# all: 19
# tensor([[-1.5166, 0.9455, 0.6392],
# [-0.4090, 1.1226, -0.2256]])
# =============================================================================