Pytorch关于高维tensor的dim上操作的理解--以cosine_similarity的dim参数为例

问题的引出

关于pytorch中dim的描述个人总是弄的不是很清楚,好多地方存在着疑问,这次在实验过程中需要比较两个高维tensor的相似度,由于需要确定在哪一维进行比较,故去测试了pytorch中关于tensor维度的一些现象。

dim

关于dim许多博客都有比我更加专业的解释,dim具体的解释也不是本文的重点,这里盗用其他博客里的一张图,这张图也是我认为对dim比较好的直观的解释(原文链接),本文的重点在于对高维tensor维度上操作,即不同的操作在不同的维度上进行会有怎样的不同
Pytorch关于高维tensor的dim上操作的理解--以cosine_similarity的dim参数为例_第1张图片

Cosine_similarity

如果你想对两个tensor比较他们之间的相似度,那么torch.cosine_similarity函数是一个不错的选择,但是在该函数的参数列表中,有一个dim值,在官方文档中,值介绍了这个参数用来指定在哪一维上进行操作,但我在实际使用过程中却对这个概念理解的不好,后来经过不断的实验终于弄懂了dim的含义。

先从二维开始

大部分博客只说明了有关二维的情况,而二维的情况是比较好想的,重点是在高维如三维情况下的tensor,那么这里我们还是从二维开始,先去看一下基本的在维度上的操作
首先直观上我们可以发现,当dim选择在哪一维上操作时,相应的那一维就消失了(这里说的消失,指直观现象,但个人觉得不是特别好理解),

p1 = torch.rand([2,3])
p2 = torch.rand([2,3])
print(p1)
print(p2)
p3 = torch.cosine_similarity(p1,p2,dim=1)
print(p3)
print(p3.shape)

对第一维操作

上述代码作用在第一维上,那么他的结果是怎么样的呢?
Pytorch关于高维tensor的dim上操作的理解--以cosine_similarity的dim参数为例_第2张图片
他的结果有两个元素(个人认为这里讨论维度容易混乱,不如直接说元素的个数),对应原来的p1的shape我们发现在第一维上操作使得第一维消失了,即[2,3]->[2],这也是大多数博客的解释,但我认为这并没有揭示真正的工作过程,同时如果应用到高维的情况,很容易得到一个令人疑惑的维度。
下面让我们来试着理解一下dim的含义,上述的例子中的图片说的已经比较明显了,在dim=1上操作,实际的含义为在以第一维为单位进行操作,即对每一行进行操作(说法不严谨,但为了方便理解),或者也可以这样进行理解就是固定第一维(即tensor的列),去比较第0维(tensor的行)。

对第0维操作

那么按照该思想,如果按照第0维操作,即对每一列为单位进行操作,那么得到的应该是一个有三个元素(为了方便理解,有不严谨的表述)的结果,分别为对应列之间的相似度:
Pytorch关于高维tensor的dim上操作的理解--以cosine_similarity的dim参数为例_第3张图片
通过实验验证,我发现确实如此。

三维的情况

在实际应用中,tensor的形状一般是[batch,seq_len,embed]这样三维的形状,那么在三维中对不同的维度操作会有怎样的差别

对第0维操作

p1 = torch.rand([2,2,3])
p2 = torch.rand([2,2,3])
print(p1)
print(p2)
p3 = torch.cosine_similarity(p1,p2,dim=0)
print(p3)
print(p3.shape)

当取dim=0时,注意此时第0维实际上是batch的维度,则固定batch不动,比较后面的[2,3]的元素,那么后面的是怎么比较的呢?依旧是按照第0维,这里的第0维实际上是后面那个[seq_len,embed]的第0维,即对列进行操作,所以其结果为13三个元素,而原来的tensor有两个batch,所以分别比较后就有23个元素,第一行元素是第一个batch中的[2,3]个元素按照第0位求相似的的结果,同理第二行也是
Pytorch关于高维tensor的dim上操作的理解--以cosine_similarity的dim参数为例_第4张图片

对第一维或第二维操作

这里的过程实际上就和二维的情况一样的,不过需要注意的是二维情况中的第0维对应三维情况的第1维,二维情况的第1维实际上对应三维中的第2维。
不同的是,三维中是一个batch内进行比较,所以,只要在二维操作的基础上加上batch的一维就可以了
Pytorch关于高维tensor的dim上操作的理解--以cosine_similarity的dim参数为例_第5张图片

你可能感兴趣的:(pytorch,python,pytorch,人工智能,机器学习)