加载模型时出现‘module‘不存在的问题

module不存在的原因是因为可能预训练模型使用一个显卡训练,而我们自己训练的模型是多卡训练的,这时在加载模型的过程中就会出现module不存在的报错,解决方法直接上代码:

#create model
model = vgg(model_name=vgg16”,num_classes=5).to(device)
# load model weights
weights_path =./vgg16Net.pth

new_state = {}
state_dict = torch.load(weights_path, map_location=device)
for key,value in state_dict.items():
    new_state[key.replace('module. ',' ' )]=value
model.load_state_dict(new_state)
model.eval()
with torch.no_grad():
    #predict class
    output = torch.squeeze(model(img.to(device))).cpu

你可能感兴趣的:(深度学习,机器学习,人工智能)