官方介绍:torch.mean — PyTorch 1.11.0 documentation
torch.
mean
(input, dim, keepdim=False, *, dtype=None, out=None)
参数
input,要输入的张量
dim,要求均值的维度
keepdim,求完均值之后是否要保留该维度
dtype,数据格式,(输入整数会被识别为long报错)
1、当dim为空时,输出全部值的平均数
2、当dim为常数时,输出延该维度求完平均数之后的张量
这是官方实例
a = torch.randn(4, 4)
tensor([[-0.3841, 0.6320, 0.4254, -0.7384],
[-0.9644, 1.0131, -0.6549, -1.4279],
[-0.2951, -1.3350, -0.7694, 0.5600],
[ 1.0842, -0.9580, 0.3623, 0.2343]])
torch.mean(a, 1)
tensor([-0.0163, -0.5085, -0.4599, 0.1807])
torch.mean(a, 1, True)
tensor([[-0.0163],
[-0.5085],
[-0.4599],
[ 0.1807]])
3.当dim为列表时,延列表内所有维度全部求均值
这种情况官方没有给出示例,以下是我自己尝试的例子。
import torch
a=torch.tensor([
[[1,1,1],
[2,2,2]],
[[3,3,3],
[4,4,4]]
],dtype=float)
b=torch.mean(a,dim=[0,1])
print(b)
结果输出:
tensor([2.5000, 2.5000, 2.5000], dtype=torch.float64)
可见当dim=【0,1】时,该函数延0,1维度求了均值,然后保留了2维度。