错误:RuntimeError: result type Float can’t be cast to the desired output type long int
翻译:RuntimeError: 结果类型Float无法强制转换为所需的输出类型long int
原因:tensor数据类型仅仅在使用tensor.int(),tensor.long()等强制转换操作时才会实现类型转换,在in-place操作中不会实现类型转换
解决方案:将数据类型为float的tensor使用tensor.long()转换为int型参与运算
案例:在使用tensor的运算时出现报错
#分别定义两个dtype分别为int64和float64的tensor
long_tensor=torch.tensor([1,1,1],dtype=torch.int64)
float_tensor=torch.tensor([1,1,1],dtype=torch.float64)
#然后我们执行
long_tensor.add_(float_tensor)
#add_()作用与+=类似,所以也可以执行
long_tensor+=float_tensor
运行上述代码出现报错
RuntimeError: result type Double can't be cast to the desired output type Long
所以我们为float_tensor做类型转换操作
float_tensor=float_tensor.long()
之后再执行in-place操作便可以正常执行
原因分析:
数据无法转换类型,那么只可能是两个部分无法转换:
A.两种数据类型的数据在运算时无法进行数据类型转换从而进行运算
B.做in-place操作时无法进行类型转换
要判断是哪个部分出现错误,我们只要将in-place操作上两种类型数据位置转换即可,即
float_tensor.add_(long_tensor)
float_tensor+=long_tensor
此时我们发现不报错,可以正常执行,那么显而易见,是做in-place操作时出现了错误,而不是运算过程出现问题,因为此时的in-place操作不涉及类型转换但涉及计算!
那么我们进一步思考,是只有long->float转换时出现问题吗?两个int或者float进行计算时,是否存在类型转换问题?
我们做以下实验
float32_tensor=torch.tensor([1,1,1],dtype=torch.float32)
float64_tensor=torch.tensor([1,1,1],dtype=torch.float64)
float32_tensor+=float64_tensor
print(float32_tensor.dtype)
我们让float64对float32进行in-place操作,看是否会出现报错,若不会,一同检查其数据类型
torch.float32
结果喜人,并没有出现报错,但其数据类型依然没有出现我们期待的‘向上兼容’,即float32的tensor并没有因为in-place操作而转换为float64型
结论:
tensor数据类型仅仅在使用tensor.int(),tensor.long()等强制转换时才会实现类型转换,在in-place操作中不会实现类型转换