load model

def load_weights(self, base_file):
        #pretrain_dict = model_zoo.load_url(model_url)
        print('Loading weights into state dict...')
        pretrain_dict = torch.load(base_file, map_location=torch.device('cpu'))
        model_dict = self.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        self.load_state_dict(model_dict)
        print('Finished!')

def load_weight(new_model, model_dir):
    #当前网络权重字典
    orginal_dict = new_model.state_dict()
    #读取的网络权重字典
    weight_dict = torch.load(model_dir, map_location=torch.device('cpu'))
    for key, value in orginal_dict.items():
        for key2,vlaue2 in weight_dict.items():
            if key==key2 and value.size() == vlaue2.size():
                print("key对应且形状相同!")
                orginal_dict[key] = weight_dict[key2]
    new_model.load_state_dict(orginal_dict)
    print('load model Finished!')

按key读取已有的模型参数,继续训练,在网络结构略作修改的时候可以使用

加载模型,简单复现

你可能感兴趣的:(深度学习,人工智能,python)