关于Pytorch中dim使用的一点记录

pytorch的许多函数,例如torch.cat()、torch.max()、torch.mul()等,都包含了dim参数。关于dim这个函数,我想许多人跟我一样,一知半解,比较模糊,下面我就把自己关于dim的一点看法记录下来,供大家参考,欢迎各位大佬批评指正。

1.关于dim的取值

这一点很多博客都没有说清,大家可能大概知道dim可以取0,1,-1这些值。那么dim还能取其他值吗?具体的规则是什么呢?不同值代表的不同含义又是什么?

A.关于dim 取值范围:假设有一个n维的tensor,那么它的dim可以取值的范围是[-n,n-1]。举个例子:a=torch.randn(2,3,4),很明显,a是一个3维张量,那么它的dim可以取的值有-3,-2,-1,0,1,2。

B.关于不同dim取值代表的含义:还是上面的例子,对于a来说,dim=0意思就是对a的第1个维度进行操作(也就是2所在的那个维度);dim=1意思就是对a的第2个维度进行操作(也就是3所在的那个维度);dim=2意思就是对a的第3个维度进行操作(也就是4所在的那个维度)。

那么对于dim为负数时,代表的意思又是什么呢?对于dim=-1,代表的是张量最里面的那个维度,对于上面例子的a来说,就是代表了第3个维度(也就是4所在的那个维度);对于dim=-2,代表的是a的第2个维度(也就是3所在的那个维度);对于dim=-3,代表的是a的第1个维度(也就是2所在的那个维度)。

综上,以a为例,dim=0和dim=-3代表的含义一样;dim=1和dim=-2代表的含义一样;dim=2和dim=-1代表的含义一样。 

举一反三,对于二维tensor,dim=1和-1一样;dim=0和-2一样。对于一维张量:dim=0和-1结果一样。

2.关于dim的具体操作

知道了dim的取值范围和含义,但是对于指定了dim后如何进行相关操作,应该得到怎么样的结果,我也比较迷糊。所以也查了资料,总结如下:

以三维张量的torch.sum()操作为例:

y = torch.tensor([
     [
       [1, 2, 3],
       [4, 5, 6]
     ],
     [
       [1, 2, 3],
       [4, 5, 6]
     ],
     [
       [1, 2, 3],
       [4, 5, 6]
     ]
   ])

对于y这个三维张量,dim=0是第一维,它可以拆分成3个3个二维张量,那么之后的求和操作就是将这三个二维张量加起来:

关于Pytorch中dim使用的一点记录_第1张图片

得到的结果就是:

>> torch.sum(y, dim=0)
tensor([[ 3,  6,  9],
        [12, 15, 18]])

PS: 此处可以发现:对于size为(3,2,3)的三维张量y来说,如果指定dim=0操作,那么得到的结果一定是size为(2,3)的张量,相当于把第1维砍掉了。那么由此可知,如果指定dim=1操作,那么得到的结果一定是size为(3,3)的张量,相当于把第2维砍掉了,从下面的结果也可以验证。以此类推。

同理,如执行dim=1的sum()操作,则过程如下:

 关于Pytorch中dim使用的一点记录_第2张图片

结果为:

>> torch.sum(y, dim=1)
tensor([[5, 7, 9],
        [5, 7, 9],
        [5, 7, 9]])

 执行dim=2的sum()操作,则过程如下:

关于Pytorch中dim使用的一点记录_第3张图片

结果为:

>> torch.sum(y, dim=2)
tensor([[ 6, 15],
        [ 6, 15],
        [ 6, 15]])

 

参考资料:

https://towardsdatascience.com/understanding-dimensions-in-pytorch-6edf9972d3be

.理解numpy中的axis, pytorch中的dim - 知乎 

pytorch 基本函数中的 dim【详细说明】:以torch.argmax为例_月下花弄影的博客-CSDN博客_argmax(dim=1)

 

你可能感兴趣的:(Python学习,python,经验分享)