解决 Pytorch RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor

在使用Pytorch框架训练模型时,抛出RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor。

产生原因及分析

待训练网络在GPU中运算,但有部分数据未进入GPU。
基于Pytorch框架使用GPU进行训练时,输入数据(样本、标记)、网络结构均会在GPU中进行计算,例如:

……
device = 'cuda:0'
model = models.resnet18(pretrained=True).to(device)
……
inputs = inputs.to(device)
labels = labels.to(device)
……

但如果在网络计算过程中,有新加入Tensor但没有明确指定其在GPU中运算(默认是在CPU中),则会抛出上述异常。

解决办法(单GPU)

通过上述分析可知,在网络中引入新的Tensor时,显式指定其运行设备为GPU,可解决上述问题,例如:

……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype).to(device)
……

在单GPU时,该方法能够解决上述问题;但在多GPU情况下,仍会抛出不在同一个GPU上计算的异常。

解决办法(多GPU)

对于多GPU,网络中新增Tensor一般会与输入(inputs)进行计算;因此获取inputs所在的GPU设备,将新Tensor的计算设备设置为与之相同,问题得解,例如:

……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype).to(inputs.device)
……

或者

……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
……

你可能感兴趣的:(机器学习)