迁移学习加载改进后的模型参数

pytorch基于改进模型使用迁移学习加载参数

  • 前言
  • 正文

前言

由于发paper或多或少的需要一定创新性,本篇文章是基于YOLOv4算法想对模型结构进行改进,这里仅仅是做一个小样。调整主干模型并不对主干模型的输出进行改变,因此只需要将修改后模型的参数初始化。最后正常对模型进行训练即可。
迁移学习加载改进后的模型参数_第1张图片

本文也许并不一定可以帮到你,但是如果你也想使用迁移学习的方式去初始化模型的参数,那么我相信下面的文章或多或少对你有一些启发。
同时很感谢参考文章的两位作者。

正文

今天在考虑如何对YOLOv4模型进行改进如何考虑使用迁移学习的方式为没有变化的结构提供参数,目前自己大概了解到torch.load()方法是将文件加载成字典的格式。下面是Github上面一国内大神(Bubbliiiing)的yolov4代码pytorch版本进行迁移学习加载权重文件的方式。

    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location=device)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) ==  np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

但是我对YOLOv4的模型进行改进,于是考虑简单的对网络结构进行改变,但是由于又想借助迁移学习的方式进行参数初始化,借鉴了https://blog.csdn.net/guyuealian/article/details/94181896
这一位大神的代码,实现了迁移学习初始化权重。

def transfer_model(pretrained_file, model):
    pretrained_dict = torch.load(pretrained_file)  # get pretrained dict
    model_dict = model.state_dict()  # get model dict
    # 在合并前(update),需要去除pretrained_dict一些不需要的参数
    pretrained_dict = transfer_state_dict(pretrained_dict, model_dict)
    model_dict.update(pretrained_dict)  # 更新(合并)模型的参数
    model.load_state_dict(model_dict)
    return model


def transfer_state_dict(pretrained_dict, model_dict):
    # state_dict2 = {k: v for k, v in save_model.items() if k in model_dict.keys()}
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict.keys():
            if v.shape==model_dict[k].shape:
                state_dict[k] = v
            else:
                print(k,'shape dismatch')
        else:
            print("Missing key(s) in state_dict :{}".format(k))
    return state_dict
model = transfer_model("model_data/yolo4_voc_weights.pth",model)

参考文章
[1]:https://github.com/bubbliiiing/yolov4-pytorch
[2]: https://blog.csdn.net/guyuealian/article/details/94181896

你可能感兴趣的:(深度学习基础,pytorch,迁移学习,python,深度学习)