PyTorch - torch.nn.ModuleList

flyfish

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
from typing import List, Tuple
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            print("i:",i)
            print("l:",l)
            x = self.linears[i // 2](x) + l(x)  # // is floor division: divide, then round down
        return x
    
    
net=MyModule()
print(net)

input = torch.randn(10, 10)

print(net(input))

Output: each index corresponds to one layer

MyModule(
(linears): ModuleList(
(0): Linear(in_features=10, out_features=10, bias=True)
(1): Linear(in_features=10, out_features=10, bias=True)
(2): Linear(in_features=10, out_features=10, bias=True)
(3): Linear(in_features=10, out_features=10, bias=True)
(4): Linear(in_features=10, out_features=10, bias=True)
(5): Linear(in_features=10, out_features=10, bias=True)
(6): Linear(in_features=10, out_features=10, bias=True)
(7): Linear(in_features=10, out_features=10, bias=True)
(8): Linear(in_features=10, out_features=10, bias=True)
(9): Linear(in_features=10, out_features=10, bias=True)
)
)
i: 0
l: Linear(in_features=10, out_features=10, bias=True)
i: 1
l: Linear(in_features=10, out_features=10, bias=True)
i: 2
l: Linear(in_features=10, out_features=10, bias=True)
i: 3
l: Linear(in_features=10, out_features=10, bias=True)
i: 4
l: Linear(in_features=10, out_features=10, bias=True)
i: 5
l: Linear(in_features=10, out_features=10, bias=True)
i: 6
l: Linear(in_features=10, out_features=10, bias=True)
i: 7
l: Linear(in_features=10, out_features=10, bias=True)
i: 8
l: Linear(in_features=10, out_features=10, bias=True)
i: 9
l: Linear(in_features=10, out_features=10, bias=True)
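
The point of wrapping the layers in nn.ModuleList rather than a plain Python list is registration: only registered submodules show up in parameters(), state_dict(), and moves such as .to(device). A minimal sketch of the difference (the class names here are illustrative, not from the example above):

import torch.nn as nn

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # registered as submodules -> their parameters are visible to the optimizer
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # NOT registered -> parameters() will not see these layers
        self.linears = [nn.Linear(10, 10) for _ in range(3)]

print(len(list(WithModuleList().parameters())))  # 6  (weight + bias for each of the 3 layers)
print(len(list(WithPlainList().parameters())))   # 0  (an optimizer would see nothing to train)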

Another example

#mobilenetv2
GraphPath = namedtuple("GraphPath", ['s0', 'name', 's1'])  

source_layer_indexes0 = [
        GraphPath(14, 'conv', 3),
        19,
    ]

#mobilenetv1
source_layer_indexes1 = [
        12,
        14,
    ]

#squeezenet
source_layer_indexes2 = [
        12
]

#vgg
source_layer_indexes3 = [
        (23, torch.nn.BatchNorm2d(512)),
        19,
    ]



class SSD(nn.Module):
    def __init__(self, source_layer_indexes: List[int]):  # entries may be ints, (index, layer) tuples, or GraphPath
        super(SSD, self).__init__()
        self.source_layer_add_ons = nn.ModuleList([t[1] for t in source_layer_indexes
                                                   if isinstance(t, tuple) and not isinstance(t, GraphPath)])
    
        self.source_layer_indexes = source_layer_indexes
        print("init0:",source_layer_indexes)     
        print("init1",self.source_layer_add_ons)
                                                    
        
    def forward(self, x):
        for end_layer_index in self.source_layer_indexes:
            print("all:",end_layer_index)
            if isinstance(end_layer_index, GraphPath):
                path = end_layer_index
                print("GraphPath path:",path)
                end_layer_index = end_layer_index.s0
                print("GraphPath end_layer_index:",end_layer_index)
                added_layer = None
                
            elif isinstance(end_layer_index, tuple):
                print(end_layer_index)
                added_layer = end_layer_index[1]
                print("tuple added_layer",added_layer)
                end_layer_index = end_layer_index[0]
                print("tuple end_layer_index",end_layer_index)
                path = None
                
        return x
print("mobilenetv2")      
net0=SSD(source_layer_indexes0)
print(net0)
input = torch.randn(2, 3)
print(net0(input))

print("mobilenetv1")    
net1=SSD(source_layer_indexes1)
print(net1)
input = torch.randn(2, 3)
print(net1(input))

print("vgg")  
net3=SSD(source_layer_indexes3)
print(net3)
input = torch.randn(2, 3)
print(net3(input))



# =============================================================================
# mobilenetv2
# init0: [GraphPath(s0=14, name='conv', s1=3), 19]
# init1 ModuleList()
# SSD(
#   (source_layer_add_ons): ModuleList()
# )
# all: GraphPath(s0=14, name='conv', s1=3)
# GraphPath path: GraphPath(s0=14, name='conv', s1=3)
# GraphPath end_layer_index: 14
# all: 19
# tensor([[-1.1113,  0.5133, -0.2450],
#         [-0.2633, -0.7399,  0.7204]])
# mobilenetv1
# init0: [12, 14]
# init1 ModuleList()
# SSD(
#   (source_layer_add_ons): ModuleList()
# )
# all: 12
# all: 14
# tensor([[ 0.0620, -0.7680,  0.5259],
#         [-0.2823, -1.4778,  0.2337]])
# vgg
# init0: [(23, BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 19]
# init1 ModuleList(
#   (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# )
# SSD(
#   (source_layer_add_ons): ModuleList(
#     (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   )
# )
# all: (23, BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
# (23, BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
# tuple added_layer BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# tuple end_layer_index 23
# all: 19
# tensor([[-1.5166,  0.9455,  0.6392],
#         [-0.4090,  1.1226, -0.2256]])
# =============================================================================
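
Why SSD.__init__ needs both isinstance checks: GraphPath is built with namedtuple, so every GraphPath instance is also a tuple, and isinstance(t, tuple) alone would match the mobilenetv2 entry as well. The extra not isinstance(t, GraphPath) test restricts source_layer_add_ons to the plain (index, extra-layer) tuples, which is why only the vgg case ends up with a BatchNorm2d inside the ModuleList. A small check with illustrative entries:

from collections import namedtuple

GraphPath = namedtuple("GraphPath", ['s0', 'name', 's1'])

# entries mirroring the mobilenetv2 / vgg lists above
entries = [GraphPath(14, 'conv', 3), (23, "extra layer"), 19]

for t in entries:
    print(type(t).__name__, isinstance(t, tuple), isinstance(t, GraphPath))
# GraphPath True  True   -> excluded by "not isinstance(t, GraphPath)"
# tuple     True  False  -> kept; its t[1] goes into source_layer_add_ons
# int       False False  -> excluded by "isinstance(t, tuple)"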
