PyTorch模型参数迁移

模型微调或参数迁移的核心是参数的遍历与修改。PyTorch非常灵活,坑也比较多。如果发现参数迁移后模型失效,那一定是convert代码出错了,可以通过“重新构建原模型并重写参数”来验证:把模型A的参数读出来,再加载到另一个新实例化的模型A中,如果新模型是有效的,那convert代码也就是有效的。
下面的转换代码经过验证,确认有效,而且适用于RepVGG之类的等效(结构重参数化)模块的参数迁移。

import os

import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed

from models.ddet import load_model
from utils.utils import disable_print, enable_print

# Placeholder paths (a single space each) -- fill these in before running.
cfg = " "  # config passed to load_model() for the model the checkpoint is loaded into
deploy_cfg = " "  # config passed to load_model() for the deploy model
weights = " "  # checkpoint read by torch.load(); the output path is derived from it


def main():
    """Convert a trained checkpoint into a deploy-structure checkpoint.

    Steps:
      1. Build the training model from ``cfg`` and load ``weights`` into it;
         build the deploy model from ``deploy_cfg``.
      2. Copy every training tensor whose name does not belong to a
         multi-branch sub-layer ("identity", "dense", "1x1").
      3. For each module exposing ``repvgg_convert()``, store the fused
         kernel/bias under ``<module>.rbr_reparam.{weight,bias}``.
      4. Load the merged state dict into the deploy model and save the
         checkpoint next to the input as ``<weights stem>_deploy.pt``.

    Reads the module-level ``cfg``, ``deploy_cfg`` and ``weights`` settings.

    Raises:
        RuntimeError: from ``load_state_dict`` when the converted keys or
            shapes do not match the target model.
    """
    # Build both models. disable_print()/enable_print() bracket construction
    # (presumably to silence its console output); try/finally makes sure
    # printing is restored even when building or loading fails.
    disable_print()
    try:
        train_model = load_model(cfg)
        # map_location="cpu" lets a checkpoint saved on a GPU be converted on
        # a CPU-only machine.
        # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
        ckpt = torch.load(weights, map_location="cpu")
        train_model.load_state_dict(ckpt['model'])
        deploy_model = load_model(deploy_cfg)
    finally:
        enable_print()

    # Tip: iterate deploy_model.named_parameters() or
    # deploy_model.state_dict().items() and print the names to inspect the
    # target layout when debugging a key mismatch.

    # Collect the tensors that carry over unchanged. Keep them as torch
    # tensors: the original author observed precision loss when going through
    # numpy (even though the upstream RepVGG code does convert).
    branch_tags = ("identity", "dense", "1x1")
    converted_weights = {
        name: param.detach().cpu()
        for name, param in train_model.state_dict().items()
        # Skip the per-branch tensors; they are replaced by the fused ones.
        if not any(tag in name for tag in branch_tags)
    }

    # Equivalent (re-parameterized) transformation of each convertible module.
    for name, module in train_model.named_modules():
        if not hasattr(module, 'repvgg_convert'):
            continue
        kernel, bias = module.repvgg_convert()
        # The root module is named '' -- avoid emitting a key with a leading dot.
        prefix = name + '.' if name else ''
        # torch.as_tensor accepts both numpy arrays and tensors, so this works
        # whichever type repvgg_convert() returns.
        converted_weights[prefix + 'rbr_reparam.weight'] = torch.as_tensor(kernel).detach().cpu()
        converted_weights[prefix + 'rbr_reparam.bias'] = torch.as_tensor(bias).detach().cpu()
    del train_model

    # Update the deploy model. Per the original author's notes, assigning
    # through deploy_model.state_dict()[name].data or param.data had no
    # effect; merging into a state dict and reloading it is the approach
    # that works.
    model_dict = deploy_model.state_dict()
    model_dict.update(converted_weights)
    deploy_model.load_state_dict(model_dict)

    # Save the checkpoint, reusing every other entry of the original one.
    ckpt['model'] = deploy_model.state_dict()
    # splitext strips only the final extension, so dots elsewhere in the path
    # (e.g. "./runs/exp.v2/best.pt") no longer truncate the output name.
    stem, _ = os.path.splitext(weights)
    save_path = stem + "_deploy.pt"
    torch.save(ckpt, save_path)


# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()

你可能感兴趣的:(深度学习,PyTorch,参数遍历,参数修改,模型迁移,RepVGG)