pytorch的许多函数,例如torch.cat()、torch.max()、torch.mul()等,都包含了dim参数。关于dim这个函数,我想许多人跟我一样,一知半解,比较模糊,下面我就把自己关于dim的一点看法记录下来,供大家参考,欢迎各位大佬批评指正。
1.关于dim的取值
这一点很多博客都没有说清,大家可能大概知道dim可以取0,1,-1这些值。那么dim还能取其他值吗?具体的规则是什么呢?不同值代表的不同含义又是什么?
A.关于dim 取值范围:假设有一个n维的tensor,那么它的dim可以取值的范围是[-n,n-1]。举个例子:a=torch.randn(2,3,4),很明显,a是一个3维张量,那么它的dim可以取的值有-3,-2,-1,0,1,2。
B.关于不同dim取值代表的含义:还是上面的例子,对于a来说,dim=0意思就是对a的第1个维度进行操作(也就是2所在的那个维度);dim=1意思就是对a的第2个维度进行操作(也就是3所在的那个维度);dim=2意思就是对a的第3个维度进行操作(也就是4所在的那个维度)。
那么对于dim为负数时,代表的意思又是什么呢?对于dim=-1,代表的是张量最里面的那个维度,对于上面例子的a来说,就是代表了第3个维度(也就是4所在的那个维度);对于dim=-2,代表的是a的第2个维度(也就是3所在的那个维度);对于dim=-3,代表的是a的第1个维度(也就是2所在的那个维度)。
综上,以a为例,dim=0和dim=-3代表的含义一样;dim=1和dim=-2代表的含义一样;dim=2和dim=-1代表的含义一样。
举一反三,对于二维tensor,dim=1和-1一样;dim=0和-2一样。对于一维张量:dim=0和-1结果一样。
2.关于dim的具体操作
知道了dim的取值范围和含义,但是对于指定了dim后如何进行相关操作,应该得到怎么样的结果,我也比较迷糊。所以也查了资料,总结如下:
以三维张量的torch.sum()操作为例:
y = torch.tensor([
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
]
])
对于y这个三维张量,dim=0是第一维,它可以拆分成3个3个二维张量,那么之后的求和操作就是将这三个二维张量加起来:
得到的结果就是:
>> torch.sum(y, dim=0)
tensor([[ 3, 6, 9],
[12, 15, 18]])
PS: 此处可以发现:对于size为(3,2,3)的三维张量y来说,如果指定dim=0操作,那么得到的结果一定是size为(2,3)的张量,相当于把第1维砍掉了。那么由此可知,如果指定dim=1操作,那么得到的结果一定是size为(3,3)的张量,相当于把第2维砍掉了,从下面的结果也可以验证。以此类推。
同理,如执行dim=1的sum()操作,则过程如下:
结果为:
>> torch.sum(y, dim=1)
tensor([[5, 7, 9],
[5, 7, 9],
[5, 7, 9]])
执行dim=2的sum()操作,则过程如下:
结果为:
>> torch.sum(y, dim=2)
tensor([[ 6, 15],
[ 6, 15],
[ 6, 15]])
参考资料:
https://towardsdatascience.com/understanding-dimensions-in-pytorch-6edf9972d3be
.理解numpy中的axis, pytorch中的dim - 知乎
pytorch 基本函数中的 dim【详细说明】:以torch.argmax为例_月下花弄影的博客-CSDN博客_argmax(dim=1)