RuntimeError: Can only calculate the mean of floating types. Got Long instead. 解决

最近在使用pytorch搭建一个网络的时候遇到一个问题,使用torch.mean计算行或者列的平均值的时候,由于之前的tensor中全是int型,程序出现了标题中的报错,因此解决方法如下:将需要计算的tensor使用.float()函数转换成float型,示例代码如下:

import torch
a=torch.tensor([[1,2,3],[2,3,4]])
print(torch.mean(a.float(),dim=0))

注意:dim=0按行求平均,dim=1按列求平均,视情况使用

你可能感兴趣的:(机器学习,python,pytorch)