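The error in the title occurs because the input's dtype is double (float64) while the layer's weight is float (float32), which is PyTorch's default dtype for parameters. There are generally two ways to fix this mismatch.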
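A minimal sketch that reproduces the mismatch, assuming a small nn.Linear layer and a random float64 input (both are illustrative stand-ins, not the original setup). The exact error message differs between PyTorch versions, so the sketch only checks that a RuntimeError is raised.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)                       # parameters are created as float32 by default
x = torch.randn(3, 2, dtype=torch.float64)    # hypothetical double-precision input

print(layer.weight.dtype)  # torch.float32
print(x.dtype)             # torch.float64

try:
    layer(x)               # float64 input vs. float32 weight
    raised = False
except RuntimeError as err:
    # the wording of this message depends on the PyTorch version
    raised = True
    print(type(err).__name__, err)
```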
1. Convert the input from double to float
Example:
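from torch.autograd import Variable  # legacy wrapper; since PyTorch 0.4 a plain tensor works directly
input = Variable(data)   # data: the double (float64) input tensor
input = input.float()    # cast to float32 to match the float32 weights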
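Casting the input from double to float inevitably changes the values in the matrix: float32 keeps only about 7 significant decimal digits (versus about 15 to 16 for float64) and has a much smaller range. If your values are large enough, or need enough precision, that double is required, skip this approach. If switching to float has no meaningful effect on your data, this one-line cast is all you need.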
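A short sketch of solution 1, again assuming an illustrative nn.Linear layer and random float64 input. The second half shows the precision loss described above with one concrete value: float32 has a 24-bit significand, so 2**24 + 1 cannot be stored exactly and is rounded.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)                      # float32 weights (PyTorch default)
x = torch.randn(3, 2, dtype=torch.float64)   # hypothetical double input

y = layer(x.float())                         # solution 1: cast the input to float32
print(y.dtype)                               # torch.float32

# The cast is lossy: 2**24 + 1 is exact in float64 but not in float32,
# so it rounds to the nearest representable float32 value.
big = torch.tensor([2.0**24 + 1], dtype=torch.float64)
print(big.item())          # 16777217.0
print(big.float().item())  # 16777216.0
```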
2. Convert the nn.Module from float to double
Example:
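import torch.nn as nn

linear = nn.Linear(2, 2)
print(linear.weight)   # float32 by default
linear.double()        # casts floating-point parameters and buffers to float64 in place
print(linear.weight)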
Parameter containing:
tensor([[-0.6322, -0.2967],
        [-0.5611,  0.0638]], requires_grad=True)
Parameter containing:
tensor([[-0.6322, -0.2967],
        [-0.5611,  0.0638]], dtype=torch.float64, requires_grad=True)
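A small check of what linear.double() does, assuming the same 2-to-2 layer with a random double input for illustration: the call converts the module in place and returns the same object, after which a float64 input passes through without any cast.

```python
import torch
import torch.nn as nn

linear = nn.Linear(2, 2)
same = linear.double()            # converts parameters in place and returns the module itself
print(same is linear)             # True
print(linear.weight.dtype, linear.bias.dtype)  # torch.float64 torch.float64

x = torch.randn(3, 2, dtype=torch.float64)     # hypothetical double input, no cast needed now
y = linear(x)
print(y.dtype)                    # torch.float64
```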
conv = nn.Conv2d(2, 64, 3, padding=1).double()  # .double() returns the module, so it can be chained
print(conv.weight)
Parameter containing:
tensor([[[[-0.2335, -0.0215,  0.0710],
          [-0.0369,  0.2331, -0.1877],
          [-0.1738,  0.1258, -0.0202]],
         [[-0.2055,  0.0458, -0.1719],
          [ 0.1927,  0.1866, -0.0046],
          [ 0.1286,  0.1005, -0.0137]]],
        [[[ 0.0427,  0.1982,  0.0761],
          [-0.2082, -0.1114,  0.0637],
          [ 0.1352,  0.0848,  0.1489]],
         [[ 0.0228,  0.2221, -0.1518],
          [-0.0462,  0.1292, -0.1872],
          [-0.0874,  0.0772,  0.1837]]],
        [[[ 0.2184,  0.2230, -0.0358],
          [-0.2187,  0.1524,  0.0409],
          [ 0.2244,  0.1070, -0.0077]],
         [[ 0.1362, -0.1353, -0.0493],
          [-0.0381, -0.0550, -0.0046],
          [ 0.1634,  0.1359, -0.1857]]],
        ...,
        [[[-0.2268, -0.1568,  0.0236],
          [ 0.0697, -0.0852,  0.0374],
          [-0.0336, -0.0743, -0.0871]],
         [[ 0.0218, -0.0523,  0.1762],
          [-0.2168, -0.1741,  0.1094],
          [ 0.0216,  0.1672, -0.1006]]],
        [[[-0.1174,  0.1897, -0.1503],
          [-0.1382, -0.1979,  0.1617],
          [ 0.2282,  0.2337,  0.1170]],
         [[-0.1433, -0.2325,  0.1023],
          [ 0.2197, -0.1562, -0.1585],
          [-0.0422,  0.0011,  0.0161]]],
        [[[ 0.1030,  0.1959, -0.1112],
          [-0.0733, -0.0093,  0.1591],
          [ 0.1522,  0.0894,  0.1800]],
         [[ 0.0505,  0.1076, -0.0392],
          [ 0.1726,  0.0952, -0.1742],
          [-0.2179,  0.2116,  0.0131]]]], dtype=torch.float64,
       requires_grad=True)
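As the code shows, appending .double() to a layer you define (a conv or linear layer, for example) casts all of its floating-point parameters and buffers to double (float64). Parameter-free layers such as max pooling have nothing to convert, so .double() leaves them unchanged. This approach suits cases where the input must stay double. The inconvenience is that, when layers are defined one by one, every layer that holds parameters needs its own .double().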
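If the layers live inside one container module, a single .double() on that container converts every submodule recursively, which avoids repeating the suffix. A minimal sketch, assuming a hypothetical small CNN built with nn.Sequential (the layer sizes are illustrative only); it also shows that floating-point buffers such as BatchNorm's running statistics are converted, while integer buffers are left as they are.

```python
import torch
import torch.nn as nn

# Hypothetical small CNN; sizes chosen only so the shapes line up for an 8x8 input.
model = nn.Sequential(
    nn.Conv2d(2, 64, 3, padding=1),   # 2 -> 64 channels, spatial size unchanged
    nn.BatchNorm2d(64),               # has buffers (running_mean, running_var, ...)
    nn.ReLU(),
    nn.MaxPool2d(2),                  # no parameters; 8x8 -> 4x4
    nn.Flatten(),
    nn.Linear(64 * 4 * 4, 10),
)
model.double()                        # one call converts every submodule recursively

print({p.dtype for p in model.parameters()})   # {torch.float64}
print(model[1].running_mean.dtype)             # torch.float64
print(model[1].num_batches_tracked.dtype)      # torch.int64 (integer buffers are not cast)

x = torch.randn(1, 2, 8, 8, dtype=torch.float64)  # hypothetical double input batch
out = model(x)
print(out.shape, out.dtype)                    # torch.Size([1, 10]) torch.float64
```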
Reference:
cnblogs.com/wanghui-garcia/p/11285055.html