虽然没试过其他形式的多维数据,不过想来应该是一样的吧 ~~
keepdim=True
运算完之后的维度和原来一样,原来是三维数组现在还是三维数组(不过某一维度变成了1);
keepdim=False
运算完之后一般少一维度,求平均变为1的那一维没有了;
axis=k
按第k维运算,其他维度不遍,第k维变为1。
import numpy as np
import torch
x=[
[[1,2,3,4],[5,6,7,8],[9,10,11,12]],
[[13,14,15,16],[17,18,19,20],[21,22,23,24]]
]
x=torch.tensor(x).float()
#
print("shape of x:") ##[2,3,4]
print(x.shape)
#
print("shape of x.mean(axis=0,keepdim=True):") #[1, 3, 4]
print(x.mean(axis=0,keepdim=True).shape)
#
print("shape of x.mean(axis=0,keepdim=False):") #[3, 4]
print(x.mean(axis=0,keepdim=False).shape)
#
print("shape of x.mean(axis=1,keepdim=True):") #[2, 1, 4]
print(x.mean(axis=1,keepdim=True).shape)
#
print("shape of x.mean(axis=1,keepdim=False):") #[2, 4]
print(x.mean(axis=1,keepdim=False).shape)
#
print("shape of x.mean(axis=2,keepdim=True):") #[2, 3, 1]
print(x.mean(axis=2,keepdim=True).shape)
#
print("shape of x.mean(axis=2,keepdim=False):") #[2, 3]
print(x.mean(axis=2,keepdim=False).shape)