pytorch中tensor.mean(axis, keepdim)参数理解小实验

虽然没试过其他形式的多维数据,不过想来应该是一样的吧 ~~

1.结论

keepdim=True
运算完之后的维度和原来一样,原来是三维数组现在还是三维数组(不过某一维度变成了1);

keepdim=False
运算完之后一般少一维度,求平均变为1的那一维没有了;

axis=k
按第k维运算,其他维度不遍,第k维变为1。

2.实验

import numpy as np
import torch

x=[
[[1,2,3,4],[5,6,7,8],[9,10,11,12]],
[[13,14,15,16],[17,18,19,20],[21,22,23,24]]
]
x=torch.tensor(x).float()
#
print("shape of x:")  ##[2,3,4]                                                                      
print(x.shape)                                                                                   
#
print("shape of x.mean(axis=0,keepdim=True):")          #[1, 3, 4]
print(x.mean(axis=0,keepdim=True).shape)                       
#
print("shape of x.mean(axis=0,keepdim=False):")         #[3, 4]
print(x.mean(axis=0,keepdim=False).shape)                     
#
print("shape of x.mean(axis=1,keepdim=True):")          #[2, 1, 4]
print(x.mean(axis=1,keepdim=True).shape)                      
#
print("shape of x.mean(axis=1,keepdim=False):")         #[2, 4]
print(x.mean(axis=1,keepdim=False).shape)                    
#
print("shape of x.mean(axis=2,keepdim=True):")          #[2, 3, 1]
print(x.mean(axis=2,keepdim=True).shape)                     
#
print("shape of x.mean(axis=2,keepdim=False):")         #[2, 3]
print(x.mean(axis=2,keepdim=False).shape)                  

你可能感兴趣的:(Python,python)