Background: while building a PVANet-based DCN model for my own EAST implementation, I ran into the requirement that the offset branch of a DCN must be initialized to a constant (zero).
The thing to note is that the initialization we usually see is the first snippet below, which dispatches on the layer's type and initializes each kind differently.
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # Kaiming init is the usual choice for convs followed by ReLU
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm starts out as the identity transform: scale 1, shift 0
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
But the offset branch in a DCN is an ordinary nn.Conv2d, so the type-based loop above would give it Kaiming init when it actually needs a different treatment (constant zero). So after that loop we append a second pass that matches layers by attribute name:
for m in self.modules():
    if hasattr(m, 'conv1_3_offset'):  # match the offset conv by attribute name
        constant_init(m.conv1_3_offset, 0)
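constant_init is not defined in the snippet above; it matches the helper of the same name in mmcv.cnn (from mmcv.cnn import constant_init). If you don't want the dependency, a minimal equivalent sketch:

import torch.nn as nn

def constant_init(module, val, bias=0):
    # Set weight and bias to fixed constants. With val=0 the offset conv
    # initially predicts all-zero offsets, so the DCN degenerates to a
    # regular convolution at the start of training.
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)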
modules() yields every submodule of the model in order, traversing from the outside in (each container before its children):
def modules(self):
    # named_modules() is a pre-order traversal, so each container
    # is yielded before the layers it contains
    for name, module in self.named_modules():
        yield module
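A quick check of the traversal order (a toy example, not EAST code):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Sequential(nn.BatchNorm2d(8), nn.ReLU()))
for m in net.modules():
    print(type(m).__name__)
# Sequential   <- the outermost container comes first
# Conv2d
# Sequential   <- then the nested container, then its children
# BatchNorm2d
# ReLU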
Take EAST as an example, and note the class inception_block_dcn_v2(nn.Module) that I use inside merge.
When merge's __init__ runs for m in self.modules():, the layers inside inception_block_dcn_v2 are traversed as well,
so there is no need to write initialization inside inception_block_dcn_v2. Write it in merge, because the later (outer) pass overwrites whatever the child did.
In the simple case you can just put the init in the top-level EAST module. In the complicated case where layers must be initialized by name, be careful and trace from the parent class downward. A toy demonstration of the overwriting follows, then the actual EAST code.
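A minimal sketch (hypothetical Parent/Child classes, not from EAST) showing why the parent's loop wins: Child.__init__ finishes first, so the parent's pass over self.modules() runs later and overwrites it.

import torch.nn as nn

class Child(nn.Module):
    def __init__(self):
        super(Child, self).__init__()
        self.conv = nn.Conv2d(3, 3, 3)
        nn.init.constant_(self.conv.weight, 5.0)  # child initializes itself

class Parent(nn.Module):
    def __init__(self):
        super(Parent, self).__init__()
        self.child = Child()  # Child.__init__ has already run here
        for m in self.modules():  # visits Parent, Child, and Child.conv
            if isinstance(m, nn.Conv2d):
                nn.init.constant_(m.weight, 0.0)

print(Parent().child.conv.weight.unique())  # all 0., not 5. -- the parent's pass won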
class EAST(nn.Module):
    def __init__(self, pretrained=False, inception=True):
        super(EAST, self).__init__()
        self.extractor = extractor(pretrained)
        self.merge = merge(inception)
        self.output = output(512)
class merge(nn.Module):
    # conv1_5=True keeps the deepest (2^5, i.e. stride-32) feature level, the
    # normal setup; conv1_5=False drops that last level
    def __init__(self, inception=False, conv1_5=True, two=128, thr=256, four=512, five=512):
        super(merge, self).__init__()
        self.inception = inception
        self.conv1_5 = conv1_5
        if self.conv1_5:
            self.conv1 = nn.Conv2d(four + five, 128, 1)  # 512 + 512
            self.bn1 = nn.BatchNorm2d(128)
            self.relu1 = nn.ReLU()
            self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
            self.bn2 = nn.BatchNorm2d(128)
            self.relu2 = nn.ReLU()
            self.conv3 = nn.Conv2d(128 + thr, 64, 1)  # 128 + 256
        else:
            self.conv3 = nn.Conv2d(four + thr, 64, 1)  # 512 + 256
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(64 + two, 32, 1)  # 64 + 128
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.ReLU()
        self.conv6 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn6 = nn.BatchNorm2d(32)
        self.relu6 = nn.ReLU()
        if not self.inception:
            self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
            self.bn7 = nn.BatchNorm2d(32)
            self.relu7 = nn.ReLU()
        else:
            self.conv_inception = inception_block_dcn_v2(32, 32)
        # first pass: type-based init, which also reaches every layer
        # inside inception_block_dcn_v2
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # second pass: overwrite the four DCN offset convs with zeros
        for m in self.modules():
            if hasattr(m, 'conv1_3_offset'):
                constant_init(m.conv1_3_offset, 0)
            if hasattr(m, 'conv2_4_offset'):
                constant_init(m.conv2_4_offset, 0)
            if hasattr(m, 'conv3_4_offset'):
                constant_init(m.conv3_4_offset, 0)
            if hasattr(m, 'conv4_4_offset'):
                constant_init(m.conv4_4_offset, 0)
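After construction it is easy to verify that the second pass actually reached the offset convs nested inside the inception block. A sanity check I'd run once, assuming the classes above are importable:

m = merge(inception=True)
for name, module in m.named_modules():
    if name.endswith('_offset'):
        # weight and bias should both be exactly zero
        print(name, float(module.weight.abs().sum()), float(module.bias.abs().sum()))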
class inception_block_dcn_v2(nn.Module):
    def __init__(self, inchannels, channels=32):
        super(inception_block_dcn_v2, self).__init__()
        self.inchannels = inchannels
        self.channels = channels
        if self.inchannels != self.channels:
            # 1x1 projections, only needed when the channel counts differ
            self.conv1_1 = BNReLUConv2d(inchannels, channels, 1, 1, 0)
            self.conv2_1 = BNReLUConv2d(inchannels, channels, 1, 1, 0)
            self.conv3_1 = BNReLUConv2d(inchannels, channels, 1, 1, 0)
            self.conv4_1 = BNReLUConv2d(inchannels, channels, 1, 1, 0)
            self.skip_layer = BNReLUConv2d(inchannels, channels, 1, 1, 0)
        # branch 1: 1x1 -> 3x3 DCNv2
        self.conv1_2 = BNReLUConv2d(channels, channels, 1, 1, 0)
        self.conv1_3_offset = nn.Conv2d(channels, 27, 3, padding=1)  # 27 = 18 offsets + 9 masks
        self.conv1_3 = ModulatedDeformConv(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 DCNv2
        self.bn1_3 = nn.BatchNorm2d(channels)
        self.relu1_3 = nn.ReLU()
        # branch 2: 1x3 -> 3x1 -> 3x3 DCNv2
        self.conv2_2 = BNReLUConv2d(channels, channels, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.conv2_3 = BNReLUConv2d(channels, channels, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.conv2_4_offset = nn.Conv2d(channels, 27, 3, padding=1)
        self.conv2_4 = ModulatedDeformConv(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2_4 = nn.BatchNorm2d(channels)
        self.relu2_4 = nn.ReLU()
        # branch 3: 1x5 -> 5x1 -> 3x3 DCNv2
        self.conv3_2 = BNReLUConv2d(channels, channels, kernel_size=(1, 5), stride=1, padding=(0, 2))
        self.conv3_3 = BNReLUConv2d(channels, channels, kernel_size=(5, 1), stride=1, padding=(2, 0))
        self.conv3_4_offset = nn.Conv2d(channels, 27, 3, padding=1)
        self.conv3_4 = ModulatedDeformConv(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 DCNv2
        self.bn3_4 = nn.BatchNorm2d(channels)
        self.relu3_4 = nn.ReLU()
        # branch 4: 1x7 -> 7x1 -> 3x3 DCNv2
        self.conv4_2 = BNReLUConv2d(channels, channels, kernel_size=(1, 7), stride=1, padding=(0, 3))
        self.conv4_3 = BNReLUConv2d(channels, channels, kernel_size=(7, 1), stride=1, padding=(3, 0))
        self.conv4_4_offset = nn.Conv2d(channels, 27, 3, padding=1)
        self.conv4_4 = ModulatedDeformConv(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 DCNv2
        self.bn4_4 = nn.BatchNorm2d(channels)
        self.relu4_4 = nn.ReLU()
        # fuse the concatenated branches: channels*4 -> channels
        self.conv1_4_concat = BNReLUConv2d(channels * 4, channels, 1, 1, 0)
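The forward pass is omitted above. For reference, a sketch of how one branch would typically wire the offset conv into ModulatedDeformConv, following the ModulatedDeformConvPack convention in mmcv: the 27 channels split into 18 sampling offsets (2 per position of the 3x3 kernel) plus 9 modulation masks. The method name _dcn_branch1 is mine, not from the original code.

import torch

def _dcn_branch1(self, x):
    out = self.conv1_3_offset(x)               # (N, 27, H, W)
    o1, o2, mask = torch.chunk(out, 3, dim=1)  # 9 + 9 + 9 channels
    offset = torch.cat((o1, o2), dim=1)        # (N, 18, H, W) sampling offsets
    mask = torch.sigmoid(mask)                 # modulation scalars in (0, 1)
    return self.relu1_3(self.bn1_3(self.conv1_3(x, offset, mask)))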