【代码笔记】RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should b

初学者经常会遇到下图所示的问题:
【代码笔记】RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should b_第1张图片
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

这是说 我们输入的数据类型与网络参数的类型不符。

Input type为torch.cuda.FloatTensor(GPU数据类型)

weight type(即net.parameters)为torch.FloatTensor(CPU数据类型)

解决方法有三种:

方法一:

使用GPU,convert your network to cuda
net = net.cuda()

方法二:

使用GPU
device = torch.device(‘cuda:0’)
net.to(device)

方法三:

使用CPU,就是 call torchsummary.summary with device=‘cpu’
torchsummary.summary(model,device=‘cpu’)

你可能感兴趣的:(机器学习,深度学习,pytorch,机器学习)