model_state_dict网络部分参数的更新(非严格模式加载model)

基础:字典的update方法

marks = {'Physics':67, 'Maths':87}
internal_marks = {'Practical':48}
marks.update(internal_marks)

print(marks)
# Output: {'Physics': 67, 'Maths': 87, 'Practical': 48}

不一致时

model.load_state_dict(model_state_dict, strict=False)

保存模型的key多

model_state_dict, optimizer_state_dict = loader(os.path.join(checkpoint_floder, config["model_loadname"]))

model_dict = model.state_dict()
model_state_dict = {k: v for k, v in model_state_dict.items() if k in model_dict}
model_dict.update(model_state_dict)

model.load_state_dict(model_dict)

保存模型的key少

# 先获取保存的key
model_state_dict, optimizer_state_dict = loader(os.path.join(checkpoint_floder, config["model_loadname"]))

# 获取当前模型的key
model_dict = model.state_dict()
model_dict.update(model_state_dict)# 用保存的key部分更新当前模型的key,得到model_dict

model.load_state_dict(model_dict)# 将更新的model_dict加载给当前的模型

你可能感兴趣的:(深度学习,python,开发语言)