Initializing Complex PyTorch Models

Background: while implementing my own EAST with a PVANet backbone and DCN layers, I ran into the fact that the offset layers inside the DCN need to be initialized to a constant.

The common initialization pattern

Note that the initialization you usually see is the first snippet below: it iterates over self.modules() and dispatches on the layer type.

for m in self.modules():
	if isinstance(m, nn.Conv2d):
		# Kaiming/He init, the usual choice for conv layers followed by ReLU
		nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
		if m.bias is not None:
			nn.init.constant_(m.bias, 0)
	elif isinstance(m, nn.BatchNorm2d):
		# start BN as the identity affine transform: scale 1, shift 0
		nn.init.constant_(m.weight, 1)
		nn.init.constant_(m.bias, 0)

However, a DCN offset layer is itself an ordinary nn.Conv2d, so the type-based loop above catches it too, even though it needs different treatment (constant zero rather than Kaiming). So after that loop, we add a second pass that dispatches on the attribute name:

for m in self.modules():
	if hasattr(m, 'conv1_3_offset'):
		constant_init(m.conv1_3_offset, 0)
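
Note that constant_init is not something PyTorch itself ships; it is the helper of that name from mmcv.cnn. If you want to avoid the mmcv dependency, a minimal sketch with the same (module, val, bias=0) signature looks like this:

import torch.nn as nn

def constant_init(module, val, bias=0):
    # fill the weight with a constant; val=0 on a DCN offset conv makes the
    # deformable conv start out predicting zero offsets (regular grid sampling)
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)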

What self.modules() really does

It yields every module of the model one by one, recursively and from the outside in: the module itself first, then its submodules, depth first.

def modules(self):
    for name, module in self.named_modules():
        yield module
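
A quick demo of the traversal order on a hypothetical two-level model (Inner/Outer are made-up names, not part of the EAST code):

import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super(Inner, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

class Outer(nn.Module):
    def __init__(self):
        super(Outer, self).__init__()
        self.inner = Inner()
        self.bn = nn.BatchNorm2d(8)

for m in Outer().modules():
    print(type(m).__name__)
# prints: Outer, Inner, Conv2d, BatchNorm2d -- outside in, depth first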

The key point: where to write the initialization

Take EAST as an example. Pay attention to the class inception_block_dcn_v2(nn.Module) that I wrote inside merge.

When you run for m in self.modules(): inside the merge class, the layers inside inception_block_dcn_v2 get traversed as well. So there is no need to write any initialization inside inception_block_dcn_v2 itself; write it in merge, because merge's pass runs later and would overwrite the child's anyway.

In simple cases you can just write the initialization once in the top-level EAST class. In the complex case where initialization depends on layer names, you have to be careful and reason from the parent class downward, as the short demo below shows.
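
Here is a minimal sketch of that overwrite behaviour, with made-up Child/Parent modules: the child initializes its own conv inside its __init__, but the parent's loop runs after the child is constructed and wins.

import torch.nn as nn

class Child(nn.Module):
    def __init__(self):
        super(Child, self).__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)
        nn.init.constant_(self.conv.weight, 1)  # the child's own init

class Parent(nn.Module):
    def __init__(self):
        super(Parent, self).__init__()
        self.child = Child()          # the child's init runs here...
        for m in self.modules():      # ...then the parent's pass overwrites it
            if isinstance(m, nn.Conv2d):
                nn.init.constant_(m.weight, 0)

print(Parent().child.conv.weight.unique())  # tensor([0.]) -- the parent wins

The actual EAST code, for reference: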

class EAST(nn.Module):
	def __init__(self, pretrained=False, inception=True):
		super(EAST, self).__init__()
		self.extractor = extractor(pretrained)
		self.merge     = merge(inception)
		self.output    = output(512)
class merge(nn.Module):  # conv1_5=True keeps the 2^5 (stride-32) level, the normal case; False drops that last level
	def __init__(self, inception=False, conv1_5=True, two=128, thr=256, four=512, five=512):

		super(merge, self).__init__()
		self.inception = inception
		self.conv1_5 = conv1_5
		if self.conv1_5:
			self.conv1 = nn.Conv2d(four+five, 128, 1)# 512+512
			self.bn1 = nn.BatchNorm2d(128)
			self.relu1 = nn.ReLU()
			self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
			self.bn2 = nn.BatchNorm2d(128)
			self.relu2 = nn.ReLU()

			self.conv3 = nn.Conv2d(128+thr, 64, 1)#256+128
		else:
			self.conv3 = nn.Conv2d(four+thr, 64, 1)#256+128

		self.bn3 = nn.BatchNorm2d(64)
		self.relu3 = nn.ReLU()
		self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
		self.bn4 = nn.BatchNorm2d(64)
		self.relu4 = nn.ReLU()

		self.conv5 = nn.Conv2d(64+two, 32, 1)#128+64
		self.bn5 = nn.BatchNorm2d(32)
		self.relu5 = nn.ReLU()
		self.conv6 = nn.Conv2d(32, 32, 3, padding=1)
		self.bn6 = nn.BatchNorm2d(32)
		self.relu6 = nn.ReLU()
		if not self.inception:
			self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
			self.bn7 = nn.BatchNorm2d(32)
			self.relu7 = nn.ReLU()
		else:
			self.conv_inception = inception_block_dcn_v2(32,32)

		# pass 1: generic type-based init; self.modules() also reaches every
		# layer inside conv_inception, the DCN offset convs included
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)
			elif isinstance(m, nn.BatchNorm2d):
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 0)

		# pass 2: re-init the offset convs to constant 0, so each deformable
		# conv starts out sampling the regular 3x3 grid
		for m in self.modules():
			if hasattr(m, 'conv1_3_offset'):
				constant_init(m.conv1_3_offset, 0)
			if hasattr(m, 'conv2_4_offset'):
				constant_init(m.conv2_4_offset, 0)
			if hasattr(m, 'conv3_4_offset'):
				constant_init(m.conv3_4_offset, 0)
			if hasattr(m, 'conv4_4_offset'):
				constant_init(m.conv4_4_offset, 0)
class inception_block_dcn_v2(nn.Module):
	def __init__(self, inchannels, channels=32):
		super(inception_block_dcn_v2,self).__init__()
		self.inchannels = inchannels
		self.channels = channels
		if self.inchannels!=self.channels:
			self.conv1_1 = BNReLUConv2d(inchannels,channels, 1, 1, 0)
			self.conv2_1 = BNReLUConv2d(inchannels,channels, 1, 1, 0)
			self.conv3_1 = BNReLUConv2d(inchannels,channels, 1, 1, 0)
			self.conv4_1 = BNReLUConv2d(inchannels,channels, 1, 1, 0)
			self.skip_layer = BNReLUConv2d(inchannels,channels, 1, 1, 0)
		self.conv1_2 = BNReLUConv2d(channels, channels, 1, 1, 0)
		self.conv1_3_offset = nn.Conv2d(channels, 27, 3, padding=1)  # 27 = 3*3*3: (dx, dy, mask) per kernel position
		self.conv1_3 = ModulatedDeformConv(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 modulated DCN (v2)
		self.bn1_3 = nn.BatchNorm2d(channels)
		self.relu1_3 = nn.ReLU()

		self.conv2_2 = BNReLUConv2d(channels, channels, kernel_size=(1, 3), stride =1, padding=(0, 1))
		self.conv2_3 = BNReLUConv2d(channels, channels, kernel_size=(3, 1), stride =1, padding=(1, 0))
		self.conv2_4_offset = nn.Conv2d(channels, 27, 3, padding=1)
		self.conv2_4 = ModulatedDeformConv(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn2_4 = nn.BatchNorm2d(channels)
		self.relu2_4 = nn.ReLU()

		self.conv3_2 = BNReLUConv2d(channels, channels, kernel_size=(1, 5), stride =1, padding=(0, 2))
		self.conv3_3 = BNReLUConv2d(channels, channels, kernel_size=(5, 1), stride =1, padding=(2, 0))
		self.conv3_4_offset = nn.Conv2d(channels, 27, 3, padding=1)
		self.conv3_4 = ModulatedDeformConv(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 modulated DCN (v2)
		self.bn3_4 = nn.BatchNorm2d(channels)
		self.relu3_4 = nn.ReLU()

		self.conv4_2 = BNReLUConv2d(channels, channels, kernel_size=(1, 7), stride =1, padding=(0, 3))
		self.conv4_3 = BNReLUConv2d(channels, channels, kernel_size=(7, 1), stride =1, padding=(3, 0))
		self.conv4_4_offset = nn.Conv2d(channels, 27, 3, padding=1)
		self.conv4_4 = ModulatedDeformConv(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 modulated DCN (v2)
		self.bn4_4 = nn.BatchNorm2d(channels)
		self.relu4_4 = nn.ReLU()

		self.conv1_4_concat = BNReLUConv2d(channels*4, channels, 1, 1, 0)  # fuse the 4 concatenated branch outputs
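
The forward pass of this class is not shown here, but for one branch it would follow the usual DCNv2 pattern: split the 27 offset-conv channels into 18 offset channels and 9 mask channels, then feed all three tensors to the deformable conv. A hypothetical sketch (assuming an mmcv-style ModulatedDeformConv whose forward takes (x, offset, mask); this is not the author's actual forward):

	def forward_branch1(self, x):
		x = self.conv1_2(x)
		out = self.conv1_3_offset(x)
		o1, o2, mask = torch.chunk(out, 3, dim=1)  # 9 + 9 + 9 channels
		offset = torch.cat((o1, o2), dim=1)        # 18 offset channels (dx, dy per position)
		mask = torch.sigmoid(mask)                 # 9 modulation masks in (0, 1)
		x = self.conv1_3(x, offset, mask)
		return self.relu1_3(self.bn1_3(x))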

 
