MindSpore-TOOD模型权重迁移推理对齐实录

准备工作

环境:
wsl2 Ubuntu 20.04
mindspore 2.0.0
python 3.8
pytorch 2.0.1 cpu

基于自己编写的mindspore TOOD项目和MMDetection实现的pytorch权重来做迁移,

  • TOOD论文pytorch mmdetection实现
    tood_r50_fpn_1x_coco权重
    论文中的代码也是用mmdetection实现的
  • TOOD mmdetection实现
    观察上面两个实现的配置文件,区别只是分类损失用的不同,我们先对照TOOD mmdetection实现。
  • MindSpore TOOD项目链接
    该代码基于FCOS mindspore实现的,对网络命名进行了优化,更靠近官方的pytorch风格

基于MindSpore实现TOOD forward 结构

先搭模型,结构就是resnet50+fpn+toodhead。除了模型结构,还要注意head以及fpn部分的权值初始化要与mmdetection中的实现对齐,这个在后续训练时会有影响

  • 两种框架下pad的区别需要注意,区别见MindSpore官方的迁移指南 ,我尽量使用显式表达,防止出错
  • resent50 backbone在训练时加载预训练权重进行初始化
  • mmdetection中FPN部分的初始化为xavier初始化,我在mindspore中采用更好的kaiming初始化
  • head部分卷积和一般性的偏置使用normal初始化以及zeros初始化
  • head部分的分类分支偏置采用的prob初始化
  • 其他部分(BN,GN)的初始化两个框架相同

权重转换

迁移其实就是在做权重的键值映射对齐,有了FCOS的迁移经验,且对网络模型部分做了命名优化,做这个会快很多。

可参考的经验:

  • FCOS权重迁移经验
  • https://gitee.com/lirongxi4/pt2ms_convert
    一个迁移脚本,通用性一般
  • MindSpore官方的迁移指南

打印两种框架的权重的名称及shape进行比对,
利用文本对比网站进行对比:
MindSpore-TOOD模型权重迁移推理对齐实录_第1张图片
根据shape可以看到顺序完全对齐了,注意scale在pt中是一个浮点数,而在ms中是一个1x1的tensor。FPN实现的运算顺序也在代码中专门调试过,只需完成名称转换即可。

虽然可以根据顺序直接转换,但为了稳定性,还是用字典映射的方法,总结的名称转换方式如下(pytorch的名称改为mindspore的):

def tood_pth2ckpt():
    ms_ckpt = ms.load_checkpoint('tood_ms.ckpt')  # mindspore FCOS保存的随机权重
    pth = torch.load("/mnt/f/pretrain_weight/tood_r50_fpn_1x_coco.pth", map_location=torch.device('cpu'))  # pytorch FCOS权重
    match_pt_kv = {}  # 匹配到的pt权重的name及value的字典
    match_pt_kv_mslist = []  # 匹配到的pt权重的name及value的字典, mindspore加载权重需求的格式
    not_match_pt_kv = {}  # 未匹配到的pt权重的name及value
    matched_ms_k = []  # 被匹配到的ms权重名称

    '''一般性的转换规则'''
    pt2ms = {'backbone': 'tood_body.backbone',  # backbone部分
             'neck': 'tood_body.fpn',
             'bbox_head': 'tood_body.head',
             'downsample': 'down_sample_layer',
             }

    '''conv层的转换规则, 一致,可忽略'''
    pt2ms_conv = {
        "weight": "weight",
        "bias": "bias",
    }

    '''downsample层的转换规则, 有卷积层和bn层, 分别为0,1命名,在torch中weight重复'''
    pt2ms_down = {
        "0.weight": "0.weight",
        "1.weight": "1.gamma",

        "1.bias": "1.beta",
        "running_mean": "moving_mean",
        "running_var": "moving_variance",
    }

    '''BN层的转换规则'''
    pt2ms_bn = {
        "running_mean": "moving_mean",
        "running_var": "moving_variance",
        "weight": "gamma",
        "bias": "beta",
    }

    '''GN层的转换规则'''
    pt2ms_gn = {
        "weight": "gamma",
        "bias": "beta",
    }

    for i, v in pth['state_dict'].items():
        pt_name = copy.deepcopy(i)
        pt_value = copy.deepcopy(v)
        '''一般性的处理'''
        for k, v in pt2ms.items():
            if k in pt_name:
                pt_name = pt_name.replace(k, v)

        '''conv层的转换规则, 一致,可忽略'''

        '''FPN部分特别处理'''
        if 'fpn' in pt_name:
            pt_name = pt_name.replace('.conv', '')

        '''下采样层特别处理'''
        if 'down' in pt_name:
            for k, v in pt2ms_down.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''BN层处理'''
        if 'bn' in pt_name:
            for k, v in pt2ms_bn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''GN层处理'''
        if 'gn' in pt_name:
            for k, v in pt2ms_gn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''改名成功,匹配到ms中的权重了,记录'''
        if pt_name in ms_ckpt.keys():
            if 'scale' in pt_name:
                pt_value = torch.tensor([pt_value])
            assert pt_value.shape == ms_ckpt[pt_name].shape
            match_pt_kv[pt_name] = pt_value
            match_pt_kv_mslist.append({'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), ms_ckpt[pt_name].dtype)})
            matched_ms_k.append(pt_name)
        else:
            not_match_pt_kv[i + '   ' + pt_name] = pt_value

    '''打印未匹配的pt权重名称'''
    print('\n\n-----------------------------未匹配的pt权重名称----------------------------')
    print('----------原名称--------                        ----------转换后名称---------')
    for j, v in not_match_pt_kv.items():
        print(j, np.array(v.shape))

    '''打印未被匹配到的ms权重名称'''
    print('\n\n---------------------------未被匹配到的ms权重名称----------------------------')
    for j, v in ms_ckpt.items():
        if j not in matched_ms_k:
            print(j, np.array(v.shape))
    print('end')
    return match_pt_kv_mslist

输出:

-----------------------------未匹配的pt权重名称----------------------------
----------原名称--------                        ----------转换后名称---------
backbone.layer4.1.bn3.num_batches_tracked   tood_body.backbone.layer4.1.bn3.num_batches_tracked []
backbone.layer4.2.bn1.num_batches_tracked   tood_body.backbone.layer4.2.bn1.num_batches_tracked []
backbone.layer4.2.bn2.num_batches_tracked   tood_body.backbone.layer4.2.bn2.num_batches_tracked []
backbone.layer4.2.bn3.num_batches_tracked   tood_body.backbone.layer4.2.bn3.num_batches_tracked []
......

---------------------------未被匹配到的ms权重名称----------------------------
end

剩下一些bn层的num_batches_tracked状态,不需要管

接下来进行输出对齐,推理到需要padding的卷积时发现了一些问题,
mindspore中

nn.Conv2d(64, 64, kernel_size=3, stride=1,
                     padding=1, pad_mode='pad', has_bias=False)

不等价于pytorch的

nn.Conv2d(64, 64, kernel_size=3, stride=1,
                     padding=1)

查阅资料按道理应该等价的啊,结果不等价
发现是跟ms中这样等价的, 先pad,再valid卷积:

pad1 = ms.nn.Pad(((0,0),(0,0),(1,1),(1,1)))
conv2 = ms.nn.Conv2d(64, 64, kernel_size=3, stride=1,
                      pad_mode='valid')

不解。。。

未完待续。。。

你可能感兴趣的:(目标检测,mindspore,深度学习,机器学习,人工智能)