Initializing model weights in bulk with self.modules()
self.modules() iterates over every module that makes up the network, including all of those modules' descendant modules.
Example:
Create a network that contains a predefined DoubleConv class:
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class Normal_Down_Sampling(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Normal_Down_Sampling, self).__init__()
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels),
            nn.Conv2d(in_channels, out_channels, 7, 2),  # 7x7 kernel, stride=2
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

Net = Normal_Down_Sampling(3, 64)
Iterate over the network's modules:
for i, m in enumerate(Net.modules()):
    print(i, m)
The output is as follows:
The first entry (index 0) is the network itself:
0 Normal_Down_Sampling(
  (conv): Sequential(
    (0): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (3): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU(inplace=True)
      )
    )
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2))
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
  )
)
The second entry (index 1) is the Sequential block inside the network:
1 Sequential(
  (0): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU(inplace=True)
    )
  )
  (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2))
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU(inplace=True)
)
The third entry (index 2) is the DoubleConv module inside that Sequential block:
2 DoubleConv(
  (conv): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
  )
)
The fourth entry (index 3) is the Sequential block inside DoubleConv:
3 Sequential(
  (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (3): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): ReLU(inplace=True)
)
The remaining entries are PyTorch's basic leaf modules, which are the ones we actually initialize (a short filtering sketch follows the listing):
4 Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
5 BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
6 Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
7 BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
8 ReLU(inplace=True)
9 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2))
10 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
11 ReLU(inplace=True)
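Since only these leaf modules hold the parameters we want to re-initialize, a quick way to see what the traversal will actually touch is to filter it with isinstance. A minimal sketch, assuming the Net object created above:

for m in Net.modules():
    if isinstance(m, nn.Conv2d):
        print("Conv2d      ->", tuple(m.weight.shape))  # (out_channels, in_channels, kH, kW)
    elif isinstance(m, nn.BatchNorm2d):
        print("BatchNorm2d ->", m.num_features)         # number of channels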
Initializing the network parameters (this can be placed at the end of the network's __init__ function):
# Requires `import math` at the top of the file.
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # He (Kaiming) initialization with n = fan-out = kH * kW * out_channels
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm: scale (weight) to 1, bias to 0
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.ConvTranspose2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
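The loop above is He (Kaiming) normal initialization written out by hand: a zero-mean normal with std = sqrt(2/n), where n is the layer's fan-out. For reference, a sketch of the same idea using the torch.nn.init helpers; for Conv2d, mode='fan_out' with nonlinearity='relu' reproduces the sqrt(2. / n) computation above (the ConvTranspose2d weight layout differs, so treat that branch as an approximation):

for m in self.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # Kaiming normal: std = sqrt(2 / fan_out), matching sqrt(2. / n) above for Conv2d
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)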
Summary: self.modules() lets you traverse all the way down to the network's basic building blocks (PyTorch's own leaf modules) and initialize them in place.
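To tie the pieces together, a minimal sketch of how the loop can sit at the end of __init__ and a quick check of the resulting weight statistics; the Down class and the _initialize_weights method name are illustrative choices for this sketch, not part of the original code:

import math
import torch.nn as nn

class Down(nn.Module):  # hypothetical, simplified module just for this sketch
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self._initialize_weights()  # run once, at the end of __init__

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        return self.conv(x)

net = Down(3, 64)
conv = net.conv[0]
n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
print(conv.weight.data.std().item(), math.sqrt(2. / n))  # both around 0.059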