PyTorch:expected scalar type Float but found Double

原因分析:

代码中网络参数类型不统一

解决方案:

在最前面加

import torch
torch.set_default_tensor_type(torch.DoubleTensor)

或者在网络初始化之后加=

net = net.double()

转载:https://blog.csdn.net/sazass/article/details/109725458

你可能感兴趣的:(机器学习)