RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the

pytorch报错:RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

源代码:

# 加载模型
model = torch.load("test_0.pth")
model.eval()
with torch.no_grad():
    output = model(img)

原因:

GPU上面训练的模型在cpu上进行加载和使用

修改:

加载模型时添加 map_location=torch.device(‘cpu’)

# 加载模型
model = torch.load("test_0.pth", map_location=torch.device('cpu'))

你可能感兴趣的:(Error_pytorch,pytorch,深度学习,人工智能)