本篇文章介绍如何将Tensorflow的Conv2D层、Dense层和BatchNorm层的权重迁移至Pytorch。关于将Pytorch的权重迁移至Tensorflow,可参考下面的博文:
Pytorch与Tensorflow权重互转_太阳花的小绿豆的博客-CSDN博客_tensorflow权重转pytorch
Tensorflow的模型中的每一层一般都会有个name来指定该层的名称。获取某一层时可以使用model.get_layer(name)方法得到,最后使用layer.get_weights()获得权重。而Pytorch的模块里面并没有相关变量指定该层名称,我们可以重新封装这些模块,并指定一个变量来存放名字,这样可以按照Tensorflow模型的结构搭建Pytorch模型,并逐层迁移权重。
Tensorflow的数据维度为(B,H,W,C),而Pytorch的数据维度为(B,C,H,W),因此二者卷积层权重矩阵的维度顺序也不一样:Pytorch为(out_channels,in_channels,kH,kW),Tensorflow为(kH,kW,in_channels,out_channels),其中kH、kW分别为卷积核的高和宽。因此迁移权重时需要对权重矩阵做维度置换,即permute(3,2,0,1)。
此外,如果卷积带有bias,layer.get_weights()返回长度为2的列表:第一个元素为权重矩阵,第二个元素为bias。
class Conv2dWithName(nn.Module):
    """``nn.Conv2d`` wrapper that carries the name of its TensorFlow counterpart.

    The stored ``name`` lets a migration loop look up the matching TensorFlow
    layer with ``tf_model.get_layer(name)`` and hand it to :meth:`set_weight`.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size, stride, padding, groups, dilation: forwarded unchanged
            to ``nn.Conv2d``.
        use_bias: whether the convolution owns a bias term.
        name: name of the corresponding TensorFlow layer.
        bias: optional alias for ``use_bias`` (the keyword ``nn.Conv2d``
            itself uses). When given (not ``None``) it overrides ``use_bias``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0, groups=1,
                 use_bias=True, dilation=1, name=None, bias=None):
        super(Conv2dWithName, self).__init__()
        # Helpers in this file construct the layer with ``bias=...``; accept
        # that spelling too instead of failing with a TypeError.
        if bias is not None:
            use_bias = bias
        self.conv2d = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                                padding=padding, groups=groups, bias=use_bias, dilation=dilation)
        self.name = name  # name of the matching TensorFlow layer
        self.use_bias = use_bias

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        return self.conv2d(x)

    def set_weight(self, layer):
        """Copy kernel (and bias, if enabled) from a TensorFlow layer.

        Args:
            layer: object exposing ``get_weights()``; element 0 is the kernel
                and, when ``use_bias`` is true, element 1 is the bias.
        """
        with torch.no_grad():
            print('INFO: init layer %s with tf weights' % self.name)
            weights = layer.get_weights()
            kernel = torch.from_numpy(weights[0])
            # Move axes 3,2,0,1 to the front: a kernel stored as
            # (kH, kW, in, out) becomes PyTorch's (out, in, kH, kW).
            kernel = kernel.permute((3, 2, 0, 1))
            self.conv2d.weight.copy_(kernel)
            if self.use_bias:
                bias = torch.from_numpy(weights[1])
                self.conv2d.bias.copy_(bias)
类似的,dense层包含weight,bias两个权重参数。需要注意的是Pytorch的weight维度为(out_dims,in_dims),而Tensorflow正好相反为(in_dims,out_dims)。
class DenseWithName(nn.Module):
    """``nn.Linear`` wrapper tagged with the name of its TensorFlow Dense layer."""

    def __init__(self, in_dim, out_dim, name=None):
        super(DenseWithName, self).__init__()
        self.dense = nn.Linear(in_dim, out_dim)
        self.name = name  # name of the matching TensorFlow layer

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.dense(x)

    def set_weight(self, layer):
        """Load kernel and bias from ``layer.get_weights()``.

        Element 0 is the kernel; its two axes are swapped before the copy
        because ``nn.Linear`` stores its weight as (out_dims, in_dims).
        Element 1 is the bias.
        """
        print('INFO: init layer %s with tf weights' % self.name)
        with torch.no_grad():
            tf_params = layer.get_weights()
            kernel = torch.transpose(torch.from_numpy(tf_params[0]), 0, 1)
            self.dense.weight.copy_(kernel)
            offset = torch.from_numpy(tf_params[1])
            self.dense.bias.copy_(offset)
BatchNorm需要迁移weight、bias、running_mean、running_var四个参数,它们依次取自layer.get_weights()返回列表中的gamma、beta、滑动均值和滑动方差(下标0、1、2、3)。
class BatchNorm2dWithName(nn.Module):
    """``nn.BatchNorm2d`` wrapper tagged with the name of its TensorFlow twin.

    Args:
        n_chaanels: number of channels to normalise (spelling kept for
            backward compatibility with existing callers).
        name: name of the corresponding TensorFlow layer.
        eps: numerical-stability constant forwarded to ``nn.BatchNorm2d``.
    """

    def __init__(self, n_chaanels, name=None, eps=1e-5):
        super(BatchNorm2dWithName, self).__init__()
        self.bn = nn.BatchNorm2d(n_chaanels, eps=eps)
        self.name = name  # name of the matching TensorFlow layer

    def forward(self, x):
        """Apply the wrapped batch normalisation to ``x``."""
        return self.bn(x)

    def set_weight(self, layer):
        """Copy scale, shift and running statistics from a TensorFlow layer.

        Args:
            layer: object exposing ``get_weights()`` that returns, in order,
                gamma, beta, moving mean and moving variance. If it also has
                an ``epsilon`` attribute, that value replaces ``self.bn.eps``
                so both frameworks normalise with the same constant.
        """
        with torch.no_grad():
            print('INFO: init layer %s with tf weights' % self.name)
            weights = layer.get_weights()
            gamma = torch.from_numpy(weights[0])
            beta = torch.from_numpy(weights[1])
            run_mean = torch.from_numpy(weights[2])
            run_var = torch.from_numpy(weights[3])
            self.bn.bias.copy_(beta)
            self.bn.running_mean.copy_(run_mean)
            self.bn.running_var.copy_(run_var)
            self.bn.weight.copy_(gamma)
        # The source layer may use a different epsilon than PyTorch's default
        # (the Keras model in this file uses 1.001e-5); keep them in sync.
        tf_eps = getattr(layer, 'epsilon', None)
        if tf_eps is not None:
            self.bn.eps = float(tf_eps)
我们可以参照已有的Tensorflow模型结构,利用上述封装好的层来搭建深度模型。迁移权重时可以遍历模型的所有层,逐层迁移权重。
# Visit every sub-module of the model; only the name-carrying wrappers have a
# TensorFlow counterpart to load weights from.
for m in self.modules():
    if not isinstance(m, (Conv2dWithName, BatchNorm2dWithName, DenseWithName)):
        continue
    layer = tf_model.get_layer(m.name)
    m.set_weight(layer)
下面以ResNet50为例测试权重迁移:
def block1(x, filters, kernel_size=3, stride=1,
           conv_shortcut=True, name=None):
    """Build one bottleneck residual block with Keras layers.

    # Arguments
        x: input tensor.
        filters: integer, channel count of the bottleneck convolutions; the
            block output has ``4 * filters`` channels.
        kernel_size: kernel size of the middle convolution (default 3).
        stride: stride of the first convolution and of the convolutional
            shortcut (default 1).
        conv_shortcut: when exactly ``True`` the shortcut is a 1x1
            convolution followed by batch norm; otherwise the input is
            passed through unchanged.
        name: string prefix used for every layer name in the block.

    # Returns
        Output tensor of the residual block.
    """
    bn_axis = 3
    bn_eps = 1.001e-5

    # Shortcut branch.
    if conv_shortcut is not True:
        shortcut = x
    else:
        shortcut = layers.Conv2D(4 * filters, 1, strides=stride,
                                 name=name + '_0_conv')(x)
        shortcut = layers.BatchNormalization(axis=bn_axis, epsilon=bn_eps,
                                             name=name + '_0_bn')(shortcut)

    # Main branch: 1x1 (strided) -> kxk -> 1x1 expansion.
    out = layers.Conv2D(filters, 1, strides=stride, name=name + '_1_conv')(x)
    out = layers.BatchNormalization(axis=bn_axis, epsilon=bn_eps,
                                    name=name + '_1_bn')(out)
    out = layers.Activation('relu', name=name + '_1_relu')(out)

    out = layers.Conv2D(filters, kernel_size, padding='SAME',
                        name=name + '_2_conv')(out)
    out = layers.BatchNormalization(axis=bn_axis, epsilon=bn_eps,
                                    name=name + '_2_bn')(out)
    out = layers.Activation('relu', name=name + '_2_relu')(out)

    out = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(out)
    out = layers.BatchNormalization(axis=bn_axis, epsilon=bn_eps,
                                    name=name + '_3_bn')(out)

    # Merge both branches and apply the final activation.
    merged = layers.Add(name=name + '_add')([shortcut, out])
    return layers.Activation('relu', name=name + '_out')(merged)
def stack1(x, filters, blocks, stride1=2, name=None):
    """Chain ``blocks`` residual blocks into one stage.

    # Arguments
        x: input tensor.
        filters: integer, bottleneck channel count passed to every block.
        blocks: integer, number of blocks in the stage.
        stride1: stride of the first block (default 2); later blocks use the
            default stride of ``block1``.
        name: string prefix for the stage; blocks are named
            ``<name>_block1``, ``<name>_block2``, ...

    # Returns
        Output tensor of the last block.
    """
    # Only the first block gets a convolutional shortcut and the stage stride.
    out = block1(x, filters, stride=stride1, name=name + '_block1')
    block_no = 2
    while block_no <= blocks:
        out = block1(out, filters, conv_shortcut=False,
                     name=name + '_block' + str(block_no))
        block_no += 1
    return out
def ResNet50_TF(inputs,
                preact=False,
                use_bias=True,
                model_name='resnet50'):
    """Build the Keras ResNet50 variant used as the weight source.

    The network ends in a single-unit linear ``Dense`` layer named
    ``final_fc``.

    # Arguments
        inputs: Keras input tensor.
        preact: when exactly ``False`` (default) the stem convolution is
            followed by batch norm and ReLU; otherwise both are skipped.
        use_bias: whether the stem convolution ``conv1_conv`` has a bias.
            The convolutions inside the residual blocks are not affected.
        model_name: name given to the resulting Keras model.

    # Returns
        A Keras ``Model`` mapping ``inputs`` to the ``final_fc`` output.
    """
    bn_axis = 3

    # Stem: explicit zero padding followed by a 7x7 stride-2 convolution.
    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad')(inputs)
    x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv')(x)
    if preact is False:
        x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                      name='conv1_bn')(x)
        x = layers.Activation('relu', name='conv1_relu')(x)
    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(x)
    x = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(x)

    # Four residual stages with 3, 4, 6 and 3 blocks.
    x = stack1(x, 64, 3, stride1=1, name='conv2')
    x = stack1(x, 128, 4, name='conv3')
    x = stack1(x, 256, 6, name='conv4')
    x = stack1(x, 512, 3, name='conv5')

    # Head: global average pooling and a single linear output unit.
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(1, activation='linear', name='final_fc')(x)

    # Create model.
    model = models.Model(inputs, x, name=model_name)
    return model
注意:上述ResNet50模型并不是原始的ResNet50,它最后的全连接层(final_fc)输出维度为1。下面是与之对应的Pytorch实现:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, name=None):
    """Return a named 3x3 convolution with bias; padding equals ``dilation``."""
    # Conv2dWithName's keyword is ``use_bias``; passing ``bias=`` raised a TypeError.
    return Conv2dWithName(in_planes, out_planes, kernel_size=3, stride=stride,
                          padding=dilation, groups=groups, use_bias=True,
                          dilation=dilation, name=name)
def conv1x1(in_planes, out_planes, stride=1, name=None):
    """Return a named 1x1 convolution with bias."""
    # Conv2dWithName's keyword is ``use_bias``; passing ``bias=`` raised a TypeError.
    return Conv2dWithName(in_planes, out_planes, kernel_size=1, stride=stride,
                          use_bias=True, name=name)
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3) with an optional shortcut module."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        # This block has no grouped, widened or dilated variant.
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # The first convolution carries the stride; ``downsample`` (if any)
        # is applied to the shortcut in ``forward``.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run both conv/bn stages, add the shortcut and apply ReLU."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) built from named layers.

    Sub-layers are named ``<name>_1_conv``, ``<name>_1_bn``, ...,
    ``<name>_3_bn``; the two entries of ``downsample`` are renamed to
    ``<name>_0_conv`` and ``<name>_0_bn``.

    Args:
        inplanes: number of input channels.
        planes: base channel count; the block outputs ``planes * expansion``.
        stride: stride of the first 1x1 convolution.
        downsample: optional indexable module (conv at [0], norm at [1])
            applied to the shortcut.
        groups, base_width: only used to compute the bottleneck width.
        dilation: accepted but not used by this block.
        norm_layer: normalisation layer factory; defaults to
            ``BatchNorm2dWithName``.
        name: string prefix for the layer names. It is concatenated with
            suffixes, so a string must be supplied despite the ``None`` default.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, name=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = BatchNorm2dWithName
        width = int(planes * (base_width / 64.)) * groups
        # The stride is applied by the first 1x1 convolution (conv1); the
        # 3x3 convolution always runs with stride 1.
        self.conv1 = conv1x1(inplanes, width, stride=stride, name=name + '_1_conv')
        self.bn1 = norm_layer(width, name=name + '_1_bn')
        self.conv2 = conv3x3(width, width, name=name + '_2_conv')
        self.bn2 = norm_layer(width, name=name + '_2_bn')
        self.conv3 = conv1x1(width, planes * self.expansion, name=name + '_3_conv')
        self.bn3 = norm_layer(planes * self.expansion, name=name + '_3_bn')
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        if self.downsample is not None:
            # The shortcut layers are created without names by the caller.
            self.downsample[0].name = name + '_0_conv'
            self.downsample[1].name = name + '_0_bn'
        self.stride = stride
        self.name = name

    def forward(self, x):
        """Run the three conv/bn stages, add the shortcut and apply ReLU."""
        # NOTE(review): ``checkpoint`` is presumably
        # torch.utils.checkpoint.checkpoint; confirm against the file's
        # imports and check its ``use_reentrant`` behaviour before training.
        identity = x
        out = checkpoint(self.conv1, x)
        out = checkpoint(self.bn1, out)
        out = self.relu(out)
        out = checkpoint(self.conv2, out)
        out = checkpoint(self.bn2, out)
        out = self.relu(out)
        out = checkpoint(self.conv3, out)
        out = checkpoint(self.bn3, out)
        if self.downsample is not None:
            identity = checkpoint(self.downsample, x)
        out += identity
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    """PyTorch ResNet whose layers are named after their TensorFlow twins.

    The network ends in a single-output ``DenseWithName`` called ``final_fc``.

    Args:
        block: residual block class; it must accept a ``name`` keyword and
            expose an ``expansion`` attribute.
        layers: sequence of four integers, the number of blocks per stage.
        width_per_group: base width forwarded to the blocks after the first
            one of each stage.
    """

    def __init__(self, block, layers, width_per_group=64):
        super(ResNet, self).__init__()
        norm_layer = BatchNorm2dWithName
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        self.base_width = width_per_group
        # Conv2dWithName's keyword is ``use_bias``; ``bias=`` raised a TypeError.
        self.conv1 = Conv2dWithName(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                    use_bias=True, name='conv1_conv')
        self.bn1 = norm_layer(self.inplanes, name='conv1_bn')
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], name='conv2')
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       name='conv3')
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       name='conv4')
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       name='conv5')
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = DenseWithName(2048, 1, name='final_fc')

        # Default initialisation; overwritten once init_from_tf() is called.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, name=None):
        """Build one stage of ``blocks`` residual blocks.

        The first block always receives a conv + norm shortcut and the stage
        stride; the remaining blocks use the default stride and no shortcut
        module. Blocks are named ``<name>_block1``, ``<name>_block2``, ...
        """
        norm_layer = self._norm_layer
        # Created without names; the block assigns them.
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride=stride),
            norm_layer(planes * block.expansion),
        )
        stage = [block(inplanes=self.inplanes, planes=planes, stride=stride,
                       downsample=downsample, name=name + '_block1')]
        self.inplanes = planes * block.expansion
        for index in range(1, blocks):
            stage.append(block(self.inplanes, planes, base_width=self.base_width,
                               dilation=self.dilation,
                               name=name + '_block%d' % (index + 1)))
        return nn.Sequential(*stage)

    def init_from_tf(self, tf_model):
        """Load weights into every named wrapper from the same-named layer of ``tf_model``."""
        for m in self.modules():
            if isinstance(m, (Conv2dWithName, BatchNorm2dWithName, DenseWithName)):
                layer = tf_model.get_layer(m.name)
                m.set_weight(layer)

    def _forward_impl(self, x):
        """Stem -> four stages -> global average pool -> ``final_fc``."""
        x = checkpoint(self.conv1, x)
        x = checkpoint(self.bn1, x)
        x = self.relu(x)
        # Use the registered pooling module (same kernel 3, stride 2, padding 1)
        # rather than a duplicate functional call.
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Drop the two trailing size-1 spatial axes left by the adaptive pool.
        x = self.avgpool(x).squeeze(-1).squeeze(-1)
        x = self.final_fc(x)
        return x

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self._forward_impl(x)
def _resnet(arch, block, layers, **kwargs):
    """Construct a ``ResNet`` from a block class and per-stage block counts.

    ``arch`` is accepted but not used; ``kwargs`` are forwarded to ``ResNet``.
    """
    return ResNet(block, layers, **kwargs)
def resnet50_torch(**kwargs):
    """Build the 50-layer variant: ``Bottleneck`` blocks in stages of 3, 4, 6 and 3."""
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, **kwargs)
# Build the Keras reference model and load its trained weights by layer name.
input_shape = (None, None, 3)
inputs = Input(shape=input_shape)
res50_tf = ResNet50_TF(inputs)
res50_tf.load_weights('./src/Resnet——weights.h5', by_name=True)

# Build the PyTorch model and migrate the weights layer by layer.
res50_torch = resnet50_torch().float()
res50_torch.init_from_tf(res50_tf)
res50_torch.eval()  # use the migrated running statistics in batch norm

# Feed the same random image to both models: channels-last for TensorFlow,
# channels-first for PyTorch.
img = np.random.rand(1, 224, 224, 3)
img2 = torch.from_numpy(img).permute([0, 3, 1, 2]).float()
p_tf = res50_tf.predict(img)
with torch.no_grad():  # inference only; no autograd graph is needed
    p_torch = res50_torch(img2).numpy()

# Index down to a scalar explicitly: both outputs have a batch axis and a
# single output unit, and formatting a size-1 array with %f relies on an
# implicit array-to-scalar conversion that newer NumPy deprecates.
print('tensorflow predict: %f ' % p_tf[0, 0])
print('pytorch predict: %f ' % p_torch[0, 0])