Tensorflow权重迁移至Pytorch

Tensorflow权重迁移至Pytorch

本篇文章介绍将Tensorflow的Conv2D层、Dense层和BatchNorm层的权重迁移至Pytorch。关于将Pytorch迁移至Tensorflow可参考下面的博文:

Pytorch与Tensorflow权重互转_太阳花的小绿豆的博客-CSDN博客_tensorflow权重转pytorch

基本思路

Tensorflow（Keras）模型中的每一层通常都有一个name属性来指定该层的名称。可以通过model.get_layer(name)获取某一层，再用layer.get_weights()得到该层的权重（一个numpy数组列表）。而Pytorch的模块本身并不保存与Tensorflow层对应的名称（named_modules()给出的只是属性路径），因此我们可以重新封装这些模块，增加一个变量来存放对应的Tensorflow层名。这样就可以按照Tensorflow模型的结构搭建Pytorch模型，并按名称逐层迁移权重。

文章目录

  • Tensorflow权重迁移至Pytorch
    • 基本思路
    • Conv2D层
    • Dense层
    • BatchNorm层
    • 逐层迁移权重
        • Tensorflow模型
        • 对应的Pytorch 模型
        • 测试实验代码

Conv2D层

Tensorflow默认的数据维度为(B,H,W,C)，而Pytorch为(B,C,H,W)。二者卷积层权重的存储布局也不同：Pytorch为(out_channels,in_channels,H,W)，Tensorflow为(H,W,in_channels,out_channels)（与data_format无关）。因此迁移权重时需要用permute(3,2,0,1)对权重矩阵进行维度重排。

此外，如果卷积带有bias，layer.get_weights()返回长度为2的列表：第一个元素为权重矩阵，第二个元素为bias；不带bias时只返回权重矩阵。

class Conv2dWithName(nn.Module):
    """``nn.Conv2d`` wrapper that remembers the name of its Keras counterpart.

    The stored ``name`` is used to look up the matching Keras layer
    (``tf_model.get_layer(name)``) when migrating weights.

    Args:
        in_planes, out_planes, kernel_size, stride, padding, groups, dilation:
            forwarded unchanged to ``nn.Conv2d``.
        use_bias: whether the convolution has a bias (forwarded as ``bias=``);
            must agree with the Keras layer's ``use_bias``.
        name: name of the corresponding Keras ``Conv2D`` layer.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0, groups=1, use_bias=True, dilation=1, name=None):
        super(Conv2dWithName, self).__init__()
        self.conv2d = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                                padding=padding, groups=groups, bias=use_bias, dilation=dilation)
        self.name = name  # name of the matching Keras layer
        self.use_bias = use_bias

    def forward(self, x):
        return self.conv2d(x)

    @staticmethod
    def _copy_checked(dst, src, what):
        """Copy ``src`` into ``dst`` in place, refusing any shape mismatch.

        ``Tensor.copy_`` broadcasts ``src``, so without this check a wrongly
        shaped but broadcastable array would be accepted silently.
        """
        if tuple(src.shape) != tuple(dst.shape):
            raise ValueError('%s shape mismatch: tf %s vs torch %s'
                             % (what, tuple(src.shape), tuple(dst.shape)))
        dst.copy_(src)

    def set_weight(self, layer):
        """Copy the kernel (and bias) from a Keras ``Conv2D`` layer.

        Keras stores the kernel as (H, W, in_channels, out_channels), while
        PyTorch expects (out_channels, in_channels, H, W); hence the permute.

        Args:
            layer: object whose ``get_weights()`` returns ``[kernel, bias]``
                (or ``[kernel]`` when the layer has no bias).

        Raises:
            ValueError: if the number or the shapes of the Keras weights do
                not match this module.
        """
        print('INFO: init layer %s with tf weights' % self.name)
        weights = layer.get_weights()
        expected = 2 if self.use_bias else 1
        if len(weights) != expected:
            raise ValueError('layer %s: expected %d weight arrays (use_bias=%s), got %d'
                             % (self.name, expected, self.use_bias, len(weights)))
        with torch.no_grad():
            kernel = torch.from_numpy(weights[0]).permute((3, 2, 0, 1))
            self._copy_checked(self.conv2d.weight, kernel, 'layer %s kernel' % self.name)
            if self.use_bias:
                self._copy_checked(self.conv2d.bias, torch.from_numpy(weights[1]),
                                   'layer %s bias' % self.name)

Dense层

类似地，Dense层包含weight和bias两个权重参数。需要注意的是，Pytorch的weight维度为(out_dims,in_dims)，而Tensorflow正好相反，为(in_dims,out_dims)，因此迁移时需要对weight做转置。

class DenseWithName(nn.Module):
    """``nn.Linear`` wrapper that remembers the name of its Keras counterpart.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        name: name of the corresponding Keras ``Dense`` layer.
        use_bias: whether the layer has a bias; must agree with the Keras
            layer's ``use_bias``. Defaults to True (the previous behavior).
    """

    def __init__(self, in_dim, out_dim, name=None, use_bias=True):
        super(DenseWithName, self).__init__()
        self.dense = nn.Linear(in_dim, out_dim, bias=use_bias)
        self.name = name  # name of the matching Keras layer
        self.use_bias = use_bias

    @staticmethod
    def _copy_checked(dst, src, what):
        """Copy ``src`` into ``dst`` in place, refusing any shape mismatch.

        ``Tensor.copy_`` broadcasts ``src``, so without this check a wrongly
        shaped but broadcastable array would be accepted silently.
        """
        if tuple(src.shape) != tuple(dst.shape):
            raise ValueError('%s shape mismatch: tf %s vs torch %s'
                             % (what, tuple(src.shape), tuple(dst.shape)))
        dst.copy_(src)

    def set_weight(self, layer):
        """Copy kernel (and bias) from a Keras ``Dense`` layer.

        Keras stores the kernel as (in_dim, out_dim) while ``nn.Linear``
        expects (out_dim, in_dim), so the kernel is transposed.

        Raises:
            ValueError: if the number or the shapes of the Keras weights do
                not match this module.
        """
        print('INFO: init layer %s with tf weights' % self.name)
        weights = layer.get_weights()
        expected = 2 if self.use_bias else 1
        if len(weights) != expected:
            raise ValueError('layer %s: expected %d weight arrays (use_bias=%s), got %d'
                             % (self.name, expected, self.use_bias, len(weights)))
        with torch.no_grad():
            kernel = torch.from_numpy(weights[0]).transpose(0, 1)
            self._copy_checked(self.dense.weight, kernel, 'layer %s kernel' % self.name)
            if self.use_bias:
                self._copy_checked(self.dense.bias, torch.from_numpy(weights[1]),
                                   'layer %s bias' % self.name)

    def forward(self, x):
        return self.dense(x)

BatchNorm层

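BatchNorm需要迁移weight、bias、running_mean、running_var四个参数。Keras的BatchNormalization层的get_weights()依次返回gamma、beta、moving_mean、moving_variance，分别对应Pytorch的weight、bias、running_mean、running_var。另外需注意两者epsilon的默认值不同（Keras默认为1e-3，Pytorch默认为1e-5）。下文的Tensorflow模型显式使用了1.001e-5，若两边不一致，会带来微小的数值误差。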

class BatchNorm2dWithName(nn.Module):
    """``nn.BatchNorm2d`` wrapper that remembers the name of its Keras counterpart.

    Args:
        n_chaanels: number of channels (parameter name kept as-is for
            backward compatibility).
        name: name of the corresponding Keras ``BatchNormalization`` layer.
    """

    def __init__(self, n_chaanels, name=None):
        super(BatchNorm2dWithName, self).__init__()
        self.bn = nn.BatchNorm2d(n_chaanels)
        self.name = name  # name of the matching Keras layer

    def forward(self, x):
        return self.bn(x)

    @staticmethod
    def _copy_checked(dst, src, what):
        """Copy ``src`` into ``dst`` in place, refusing any shape mismatch.

        ``Tensor.copy_`` broadcasts ``src``, so without this check a wrongly
        shaped but broadcastable array would be accepted silently.
        """
        if tuple(src.shape) != tuple(dst.shape):
            raise ValueError('%s shape mismatch: tf %s vs torch %s'
                             % (what, tuple(src.shape), tuple(dst.shape)))
        dst.copy_(src)

    def set_weight(self, layer):
        """Copy gamma/beta/moving statistics and epsilon from a Keras BN layer.

        Keras returns ``[gamma, beta, moving_mean, moving_variance]``. It omits
        gamma when ``scale=False`` and beta when ``center=False``. In those
        cases the missing PyTorch affine parameter is set to its identity
        value (weight 1, bias 0).

        The layer's ``epsilon`` is copied as well. Keras and PyTorch use
        different defaults, and the TF ResNet in this file uses 1.001e-5, so
        keeping PyTorch's 1e-5 would cause small output differences.

        Raises:
            ValueError: on an unexpected number or shape of Keras weights.
        """
        print('INFO: init layer %s with tf weights' % self.name)
        weights = list(layer.get_weights())
        has_scale = bool(getattr(layer, 'scale', True))
        has_center = bool(getattr(layer, 'center', True))
        expected = 2 + int(has_scale) + int(has_center)
        if len(weights) != expected:
            raise ValueError('layer %s: expected %d BN weight arrays, got %d'
                             % (self.name, expected, len(weights)))
        gamma = weights.pop(0) if has_scale else None
        beta = weights.pop(0) if has_center else None
        run_mean, run_var = weights
        with torch.no_grad():
            if gamma is None:
                self.bn.weight.fill_(1.0)
            else:
                self._copy_checked(self.bn.weight, torch.from_numpy(gamma),
                                   'layer %s gamma' % self.name)
            if beta is None:
                self.bn.bias.fill_(0.0)
            else:
                self._copy_checked(self.bn.bias, torch.from_numpy(beta),
                                   'layer %s beta' % self.name)
            self._copy_checked(self.bn.running_mean, torch.from_numpy(run_mean),
                               'layer %s moving_mean' % self.name)
            self._copy_checked(self.bn.running_var, torch.from_numpy(run_var),
                               'layer %s moving_variance' % self.name)
        epsilon = getattr(layer, 'epsilon', None)
        if epsilon is not None:
            self.bn.eps = float(epsilon)

逐层迁移权重

我们可以参照已有的Tensorflow模型结构,利用上述封装好的层来搭建深度模型。迁移权重时可以遍历模型的所有层,逐层迁移权重。

# Every wrapper type that knows the name of its Keras counterpart.
named_layer_types = (Conv2dWithName, BatchNorm2dWithName, DenseWithName)
for module in self.modules():  # walk all sub-modules of the model
    if not isinstance(module, named_layer_types):
        continue
    # Fetch the Keras layer with the same name and copy its weights over.
    module.set_weight(tf_model.get_layer(module.name))

下面以ResNet50为例测试权重迁移:

Tensorflow模型

def block1(x, filters, kernel_size=3, stride=1,
           conv_shortcut=True, name=None):
    """Build one Keras ResNet v1 bottleneck residual block.

    Layers are named ``<name>_0_conv/_0_bn`` (projection shortcut, only when
    ``conv_shortcut is True``), ``<name>_1_*`` (1x1 conv carrying ``stride``),
    ``<name>_2_*`` (``kernel_size`` conv, 'SAME' padding), ``<name>_3_*``
    (1x1 conv to ``4 * filters``), then ``<name>_add`` and ``<name>_out``.

    Args:
        x: input tensor.
        filters: number of filters of the bottleneck convolutions.
        kernel_size: kernel size of the middle convolution (default 3).
        stride: stride of the first 1x1 convolution and of the shortcut.
        conv_shortcut: projection shortcut when True, identity otherwise.
        name: prefix used for every layer name in the block.

    Returns:
        Output tensor of the block.
    """
    bn_axis = 3

    def conv_bn(tensor, n_filters, size, tag, **conv_kwargs):
        # Conv2D followed by BatchNormalization, named <name><tag>_conv / _bn.
        tensor = layers.Conv2D(n_filters, size, name=name + tag + '_conv',
                               **conv_kwargs)(tensor)
        return layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                         name=name + tag + '_bn')(tensor)

    if conv_shortcut is True:
        shortcut = conv_bn(x, 4 * filters, 1, '_0', strides=stride)
    else:
        shortcut = x

    out = conv_bn(x, filters, 1, '_1', strides=stride)
    out = layers.Activation('relu', name=name + '_1_relu')(out)

    out = conv_bn(out, filters, kernel_size, '_2', padding='SAME')
    out = layers.Activation('relu', name=name + '_2_relu')(out)

    out = conv_bn(out, 4 * filters, 1, '_3')

    out = layers.Add(name=name + '_add')([shortcut, out])
    return layers.Activation('relu', name=name + '_out')(out)


def stack1(x, filters, blocks, stride1=2, name=None):
    """Chain ``blocks`` residual blocks named ``<name>_block1`` ... ``<name>_block<blocks>``.

    The first block uses ``block1``'s default projection shortcut and the
    stride ``stride1``. All following blocks use identity shortcuts.

    Args:
        x: input tensor.
        filters: bottleneck filter count passed to every block.
        blocks: number of blocks in the stage.
        stride1: stride of the first block (default 2).
        name: prefix for the block names.

    Returns:
        Output tensor of the last block.
    """
    out = block1(x, filters, stride=stride1, name=name + '_block1')
    for idx in range(2, blocks + 1):
        out = block1(out, filters, conv_shortcut=False,
                     name='{}_block{}'.format(name, idx))
    return out


def ResNet50_TF(inputs,
                preact=False,
                use_bias=True,
                model_name='resnet50'):
    """Build the Keras ResNet50 variant used as the migration source.

    ResNet v1 backbone (stages conv2..conv5 with 3/4/6/3 bottleneck blocks),
    global average pooling, and a single-unit linear ``Dense`` head named
    ``final_fc``.

    Args:
        inputs: Keras input tensor, channels-last (BN uses axis 3).
        preact: if True, skip the BN + ReLU after the stem convolution.
        use_bias: whether the stem convolution ``conv1_conv`` has a bias.
        model_name: name of the returned model.

    Returns:
        A ``models.Model`` mapping ``inputs`` to a (batch, 1) output.
    """
    bn_axis = 3

    # Stem: explicit zero padding followed by a 'valid' 7x7/2 convolution.
    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad')(inputs)
    x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv')(x)

    if preact is False:
        x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                      name='conv1_bn')(x)
        x = layers.Activation('relu', name='conv1_relu')(x)

    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(x)
    x = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(x)

    # Residual stages; only conv2 keeps the spatial resolution (stride1=1).
    x = stack1(x, 64, 3, stride1=1, name='conv2')
    x = stack1(x, 128, 4, name='conv3')
    x = stack1(x, 256, 6, name='conv4')
    x = stack1(x, 512, 3, name='conv5')

    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)

    # Single-output regression-style head instead of the usual classifier.
    x = layers.Dense(1, activation='linear', name='final_fc')(x)

    model = models.Model(inputs, x, name=model_name)
    return model

注意上述ResNet50模型并不是标准的ResNet50：最后的全连接层final_fc输出维度为1，而不是标准ResNet50的1000类分类输出。

对应的Pytorch 模型

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, name=None):
    """3x3 ``Conv2dWithName`` with bias, padded by ``dilation``.

    With stride 1 this keeps the spatial size, like Keras ``padding='SAME'``.
    For stride > 1, Keras 'SAME' may pad asymmetrically, so the two would not
    necessarily match.

    ``Conv2dWithName`` takes ``use_bias``; passing ``bias=`` raised TypeError.
    """
    return Conv2dWithName(in_planes, out_planes, kernel_size=3, stride=stride,
                          padding=dilation, groups=groups, use_bias=True,
                          dilation=dilation, name=name)


def conv1x1(in_planes, out_planes, stride=1, name=None):
    """1x1 ``Conv2dWithName`` with bias and no padding.

    ``Conv2dWithName`` takes ``use_bias``; passing ``bias=`` raised TypeError.
    """
    return Conv2dWithName(in_planes, out_planes, kernel_size=1, stride=stride,
                          use_bias=True, name=name)


class BasicBlock(nn.Module):
    """Two-convolution (3x3 -> 3x3) residual block, torchvision style.

    Not used by ``resnet50_torch``. Its convolutions are created via
    ``conv3x3`` without a ``name``, and its default norm layer is plain
    ``nn.BatchNorm2d``.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input for the shortcut.
        groups, base_width: only 1 and 64 are supported.
        dilation: only 1 is supported.
        norm_layer: normalization layer factory (default ``nn.BatchNorm2d``).

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # When stride != 1, conv1 and downsample both reduce the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)

class Bottleneck(nn.Module):
    """ResNet v1 bottleneck block whose layers carry the Keras layer names.

    Mirrors the TF ``block1``:
        - ``<name>_1_conv`` (1x1, carries the stride)
        - ``<name>_2_conv`` (3x3)
        - ``<name>_3_conv`` (1x1, expands by ``expansion``)
    Each conv is followed by its ``_bn`` layer. The optional projection
    shortcut ``downsample`` (a 2-element Sequential: conv, norm) is renamed
    to ``<name>_0_conv`` / ``<name>_0_bn``.

    Args:
        inplanes: input channels.
        planes: bottleneck channels; the output has ``planes * 4``.
        stride: stride of the first 1x1 convolution.
        downsample: optional ``nn.Sequential(conv, norm)`` for the shortcut.
        groups, base_width: control the bottleneck width (``groups`` is also
            applied to the 3x3 convolution).
        dilation: dilation (and padding) of the 3x3 convolution.
        norm_layer: norm factory accepting ``(channels, name=...)``;
            defaults to ``BatchNorm2dWithName``.
        name: Keras block prefix, e.g. ``'conv2_block1'``. When None, the
            sub-layers are left unnamed.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, name=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = BatchNorm2dWithName

        def sub(suffix):
            # Keras name of a sub-layer; None when the block itself is unnamed.
            return None if name is None else name + suffix

        width = int(planes * (base_width / 64.)) * groups
        # Unlike torchvision (stride on conv2), the stride sits on conv1 here,
        # matching TF block1 where '<name>_1_conv' uses strides=stride.
        # self.downsample reduces the shortcut's resolution in the same way.
        self.conv1 = conv1x1(inplanes, width, stride=stride, name=sub('_1_conv'))
        self.bn1 = norm_layer(width, name=sub('_1_bn'))
        self.conv2 = conv3x3(width, width, groups=groups, dilation=dilation,
                             name=sub('_2_conv'))
        self.bn2 = norm_layer(width, name=sub('_2_bn'))
        self.conv3 = conv1x1(width, planes * self.expansion, name=sub('_3_conv'))
        self.bn3 = norm_layer(planes * self.expansion, name=sub('_3_bn'))
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        if self.downsample is not None:
            # The caller builds the shortcut unnamed; give it the TF names here.
            self.downsample[0].name = sub('_0_conv')
            self.downsample[1].name = sub('_0_bn')
        self.stride = stride
        self.name = name

    def forward(self, x):
        identity = x

        # Each layer is wrapped in checkpoint(), presumably
        # torch.utils.checkpoint, to trade compute for memory when training.
        out = checkpoint(self.conv1, x)
        out = checkpoint(self.bn1, out)
        out = self.relu(out)

        out = checkpoint(self.conv2, out)
        out = checkpoint(self.bn2, out)
        out = self.relu(out)

        out = checkpoint(self.conv3, out)
        out = checkpoint(self.bn3, out)

        if self.downsample is not None:
            identity = checkpoint(self.downsample, x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """PyTorch mirror of ``ResNet50_TF`` whose layers carry the Keras layer names.

    Every ``Conv2dWithName`` / ``BatchNorm2dWithName`` / ``DenseWithName``
    inside is named after its Keras counterpart, so ``init_from_tf`` can copy
    the weights layer by layer.

    Args:
        block: residual block class (``Bottleneck`` for ResNet50).
        layers: number of blocks in each of the four stages (conv2..conv5).
        width_per_group: base width forwarded to every block of every stage.
    """

    def __init__(self, block, layers, width_per_group=64):
        super(ResNet, self).__init__()

        norm_layer = BatchNorm2dWithName
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        self.base_width = width_per_group

        # TF: 'conv1_pad' (3 px of zeros) + a 'valid' 7x7/2 conv == padding=3.
        # Conv2dWithName's keyword is use_bias; bias= raised TypeError.
        self.conv1 = Conv2dWithName(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                    use_bias=True, name='conv1_conv')
        self.bn1 = norm_layer(self.inplanes, name='conv1_bn')
        self.relu = nn.ReLU(inplace=True)
        # TF zero-pads ('pool1_pad') whereas MaxPool2d pads with -inf. Both give
        # the same result here because the pooled input is post-ReLU (>= 0).
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], name='conv2')
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       name='conv3')
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       name='conv4')
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       name='conv5')
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Head input = channels leaving layer4 (2048 for Bottleneck).
        self.final_fc = DenseWithName(512 * block.expansion, 1, name='final_fc')

        # Default initialization for training from scratch; init_from_tf
        # overwrites the named layers with the TF weights.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, name=None):
        """Build one stage made of ``<name>_block1`` .. ``<name>_block<blocks>``.

        The first block always gets a projection shortcut (1x1 conv + norm).
        This matches the TF ``stack1``, whose first ``block1`` call uses the
        default ``conv_shortcut=True`` for every stage, including the
        stride-1 conv2. ``Bottleneck`` gives the shortcut its TF names.
        """
        norm_layer = self._norm_layer
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride=stride),
            norm_layer(planes * block.expansion),
        )
        stage = [block(inplanes=self.inplanes, planes=planes, stride=stride,
                       downsample=downsample, base_width=self.base_width,
                       dilation=self.dilation, name=name + '_block1')]
        self.inplanes = planes * block.expansion
        for idx in range(2, blocks + 1):
            stage.append(block(self.inplanes, planes, base_width=self.base_width,
                               dilation=self.dilation,
                               name='%s_block%d' % (name, idx)))

        return nn.Sequential(*stage)

    def init_from_tf(self, tf_model):
        """Copy weights from ``tf_model`` into every named wrapper layer.

        Each wrapper looks up the Keras layer with its own ``name`` via
        ``tf_model.get_layer`` and copies that layer's weights.
        """
        for m in self.modules():
            if isinstance(m, (Conv2dWithName, BatchNorm2dWithName, DenseWithName)):
                layer = tf_model.get_layer(m.name)
                m.set_weight(layer)

    def _forward_impl(self, x):
        # Stem layers are wrapped in checkpoint() like the blocks' layers.
        x = checkpoint(self.conv1, x)
        x = checkpoint(self.bn1, x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # (B, C, 1, 1) -> (B, C), the counterpart of GlobalAveragePooling2D.
        x = self.avgpool(x).flatten(1)

        return self.final_fc(x)

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, **kwargs):
    """Instantiate ``ResNet(block, layers, **kwargs)``.

    ``arch`` is accepted for signature compatibility but is not used.
    """
    return ResNet(block, layers, **kwargs)

def resnet50_torch(**kwargs):
    """Build the PyTorch counterpart of ``ResNet50_TF``; kwargs go to ``ResNet``."""
    # Bottleneck blocks per stage conv2..conv5, matching the TF stack1 calls.
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, **kwargs)

测试实验代码

# Build the TF model and load its trained weights by layer name.
input_shape = (None, None, 3)
inputs = Input(shape=input_shape)
res50_tf = ResNet50_TF(inputs)
res50_tf.load_weights('./src/Resnet——weights.h5', by_name=True)

# Build the PyTorch mirror and copy the weights layer by layer.
res50_torch = resnet50_torch().float()
res50_torch.init_from_tf(res50_tf)
res50_torch.eval()  # BatchNorm uses the copied running statistics

# Feed the same float32 input to both: NHWC for TF, NCHW for PyTorch.
img = np.random.rand(1, 224, 224, 3).astype(np.float32)
img2 = torch.from_numpy(img).permute([0, 3, 1, 2])
p_tf = res50_tf.predict(img)
with torch.no_grad():  # inference only; no autograd graph needed
    p_torch = res50_torch(img2).numpy()

# Both outputs have shape (1, 1): index the scalar explicitly instead of
# relying on the deprecated conversion of a 1-element array to float.
print('tensorflow predict: %f ' % p_tf[0, 0])
print('pytorch predict: %f ' % p_torch[0, 0])
print('max abs diff: %e' % np.abs(p_tf - p_torch).max())

输出结果
（原文此处为输出结果截图，图片未保留）

你可能感兴趣的:(tensorflow,pytorch,深度学习)