报错Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

使用pytorch时报错:
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

其中 Input type 是我们喂进去的数据, weight type 是网络模型,看出前者位于CPU,后者由于model.cuda()已经在GPU上

处理数据:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mus = input_batch["mus"]
mus = mus.cuda()

你可能感兴趣的:(pytorch,python)