使用pytorch时报错:
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
其中 Input type 是我们喂进去的数据, weight type 是网络模型,看出前者位于CPU,后者由于model.cuda()已经在GPU上
处理数据:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mus = input_batch["mus"]
mus = mus.cuda()