pytorch学习系列(5):模型部分参数使用(迁移学习)

有两种方式:
1.

Net.load_state_dict(torch.load(model_path),strict=False)

使用strict参数,如果为True,表明预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度),这里选择为False,则不完全对等,会自动舍去多余的层和其参数。
2.

pretrained_dict=torch.load(model_path)
model_dict=Net.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#不必要的键去除掉
model_dict.update(pretrained_dict)#覆盖现有的字典里的条目
Net.load_state_dict(model_dict)
Net.load_state_dict(torch.load(model_path))

你可能感兴趣的:(pytorch,pytorch,模型参数,迁移学习,部分参数)