torch.mean

In my own test, this works:

import torch

x = torch.arange(256).view(4, 4, 4, 4).float()  # 4-D tensor of shape (4, 4, 4, 4)
x_mean = torch.mean(x, axis=[2, 3], keepdim=True)  # mean over the last two dims -> shape (4, 4, 1, 1)

nu2 = input.pow(2).mean(dim=[2, 3], keepdim=True)  # method form; input is a 4-D tensor, e.g. (N, C, H, W)
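To check the behaviour above, one can compare the functional and method forms and look at the result shapes. This is a minimal sketch assuming a recent PyTorch release; the names a and b are illustrative and not part of the original test.

```python
import torch

# Same input as above: values 0..255 reshaped to (N, C, H, W) = (4, 4, 4, 4)
x = torch.arange(256).view(4, 4, 4, 4).float()

# Functional form with the canonical keyword names dim / keepdim
a = torch.mean(x, dim=[2, 3], keepdim=True)

# Method form, mirroring the nu2 line (mean of squares per (n, c) slice)
b = x.pow(2).mean(dim=[2, 3], keepdim=True)

# keepdim=True keeps the reduced axes as size-1 dims, so the result
# still broadcasts against x (e.g. x - a)
print(a.shape)  # torch.Size([4, 4, 1, 1])
print(b.shape)  # torch.Size([4, 4, 1, 1])

# Without keepdim the reduced axes are dropped
print(torch.mean(x, dim=[2, 3]).shape)  # torch.Size([4, 4])
```

keepdim=True is what makes results like nu2 usable directly in normalization code, since a (N, C, 1, 1) tensor broadcasts over the original (N, C, H, W) input.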

In older versions, this line raises an error:

nu2 = torch.mean(input.pow(2), axis=[2, 3], keepdims=True)

error:

{TypeError}mean() received an invalid combination of arguments - got (Tensor, dim=list, keepdims=bool), but expected one of:
 * (Tensor input)
 * (Tensor input, torch.dtype dtype)
 * (Tensor input, tuple of ints dim, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, Tensor out)
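The cause is the keyword name: keepdims is wrong, it should be keepdim. Note from the "got (...)" line that axis was accepted (it shows up as dim=list); only keepdims was rejected.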
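The failing line only needs the keyword renamed. The sketch below shows the corrected call, plus a positional variant that avoids keyword names altogether; the positional order (input, dim, keepdim) matches the overloads listed in the error message. The tensor t is a hypothetical stand-in for the author's input, and the sketch assumes a PyTorch version that supports reducing over a list or tuple of dims.

```python
import torch

# Hypothetical stand-in for the author's 4-D input (N, C, H, W)
t = torch.randn(2, 3, 5, 5)

# Corrected line: keepdim (no trailing "s"); dim is the canonical name for axis
nu2 = torch.mean(t.pow(2), dim=[2, 3], keepdim=True)

# Positional form: (input, dim, keepdim), so there are no keyword names to get wrong
nu2_pos = torch.mean(t.pow(2), (2, 3), True)

# Both reduce H and W while keeping them as size-1 dims
print(nu2.shape)  # torch.Size([2, 3, 1, 1])
print(torch.equal(nu2, nu2_pos))
```

The keyword form reads more clearly; the positional form is a fallback if you are unsure which keyword aliases a given PyTorch version accepts.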
