pytorch 加载 state_dict

pytorch 加载state dict

  • pytorch 加载 state_dict
    • 什么是state_dict?
    • state_dict有什么用?
    • 如何加载state_dict

pytorch 加载 state_dict

该文章主要介绍Pytorch的state_dict是什么,以及如何加载state_dict,主要用于个人学习记录,如有错误,欢迎支持。

什么是state_dict?

在Pytorch中,一个模型的可学习参数(learnable parameters),就是权重以及偏置可以通过以下代码获取

model.parameters()

state_dict 就是一个python字典对象,该对象里面储存的是每一层所对应的可学习参数(learnable parameters) ,可通过以下代码获得

model.state_dict()

pytorch 官网定义: pytorch.

state_dict有什么用?

上面说了state_dict其实就是将模型的每一层的可学习参数保存在字典里,并且,该字典可以用于保存和加载模型,所以它有以下用法:

  1. 在模型的训练过程中保存checkpoint,用于继续训练或推理。
  2. 用于迁移学习,通过下载别人的预训练模型/state_dict,加载到自己的模型上。
  3. 类似于第二点,也是我个人的情况,修改了模型的部分结构,但仍然想给没修改的layer加载对应的state_dict。

如何加载state_dict

下面直接上代码,再解释:

    def load_network(net, load_path, strict=False, param_key='params'):
        '''
        param net: 你的模型
        param load_path: 你想要加载的state_dict的路径
        '''
        # 拿到模型的state_dict
        net_dict = net.state_dict()
        # 拿到预训练的state_dict
        load_net = torch.load(load_path)
 
        # 根据size判断是否加载权重
        for k, v in load_net.items():
            if v.size() == net_dict[k].size():
                net_dict[k] = v
        net.load_state_dict(net_dict, strict=strict)
        return net

稍微解释一下,它的逻辑主要是判断你的模型的state_dict(net_dict)和预训练权重(load_net)他们对应的layer,所对应的tensor的size是否一致,一致则导入,不一致则不到入。

参考:pix2pixhd

你可能感兴趣的:(pytorch,python,深度学习,pytorch,深度学习,python)