pytorch报错: Can only calculate the mean of floating types. Got Long instead

小问题不要慌!!!!
运行代码:

import sys
sys.path.append('..')
import torch

def simple_batch_norm_1d(x, gamma, beta):
    eps = 1e-5
    x_mean = torch.mean(x, dim=0, keepdim=True)  # dim=0在每一列上求取均值  保留维度进行 broadcast
    x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

#  5行3列表示三个特征,每个特征上有五个数据点
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
y = y.float()
print('after bn: ')
print(y)

该代码是学习pytorch数据标准化的代码,对一个tensor求一个均值和方差。
报错如下:
pytorch报错: Can only calculate the mean of floating types. Got Long instead_第1张图片
该错误提示也很明显,在求均值的时候数据类型不对,计算得到的是个long型,对其数据类型做个转换即可。
修改如下:

x_mean = torch.mean(x.float(), dim=0, keepdim=True) 

这是运行就没错误啦!!!!

你可能感兴趣的:(pytorch,深度学习,python)