如何用Pytorch加载部分权重

在我做实验的过程中,由于卷积神经网络层数的更改,导致原始网络模型的权重加载失败,经过分析,是因为不匹配造成的,如下方式可以解决.

import torch
import models
checkpoint = torch.load("./logs/01origial/model_best.pth")
model = models.__dict__["vgg"](dataset="Beans", depth=16) #提取网络结结构,分别是数据集,网络的深度和每层的输出通道数
model.load_state_dict(checkpoint['state_dict'])


model_10 = models.__dict__["vgg10"](dataset="Beans", depth=10)

model_dict = model.state_dict()
model_10_dict = model_10.state_dict()

pretrained_dict = {k: v for k, v in model_dict.items() if k in model_10_dict.keys()}
model_10_dict.update(pretrained_dict)
model_10.load_state_dict(model_10_dict)

你可能感兴趣的:(软件使用与程序语法)