Pytorch张量遍历

一、思路

        在pytorch的学习过程中,有时候我们需要对张量进行遍历,那么这是可以的。张量的遍历有许多方法,这里的话我提供两种思路:第一,在张量shape已知的情况下,我们可以通过设置多层for循环来进行遍历;第二,如果不知道shape怎么办呢?此时可以通过张量的reshape操作,然后再用一层for循环来进行遍历。

二、shape已知

        对于一个张量,如果我们知道它的shape,也就是说,shape为几个dim,那么我们就能够很好地对张量进行遍历。例如dim的个数为2,那么这个张量的shape就是(a, b)。这种情况下,我们可以设置两层for循环来对张量进行遍历。具体方式如下:

import torch

# 定义一个张量
a = torch.tensor([[0.0, 0.1, 0.1], [1.0, 1.5, 1.4], [2.1, 2.3, 2.4], [3.9, 3.9, 3.8]])

# 获取张量的shape
dim0, dim1 = a.shape

# 遍历张量
for i in range(dim0):
    for j in range(dim1):
        element = a[i][j]
        print('%.2f' % element.item(), end=' ')

        输出

0.00 0.10 0.10 1.00 1.50 1.40 2.10 2.30 2.40 3.90 3.90 3.80 

三、shape未知

        有小伙伴可能感到疑惑,那如果我都不知道这个张量的形状是多少个dim的呢?那就不可以遍历了?非也,还是能够遍历的,只不过自由度没有那么大了在这种情况下,因为我们不知道要设置多少层for循环。我的想法是把这个张量reshpe一下,把它变成一维的,让遍历交给pytorch处理。最后,我们再通过一层for循环输出一下即可。具体的方式如下:

import torch

# 定义一个张量
a = torch.tensor([[0.0, 0.1, 0.1], [1.0, 1.5, 1.4], [2.1, 2.3, 2.4], [3.9, 3.9, 3.8]])

# 对张量进行reshape,转成一维张量
b = a.reshape(-1)

# 遍历一维张量
for i in b:
    element = i.item()
    print('%.2f' % element, end=' ')

        输出

0.00 0.10 0.10 1.00 1.50 1.40 2.10 2.30 2.40 3.90 3.90 3.80 

        经过对比,这两种遍历方式居然一模一样。按理应该还有其他

四、其他思路

        除了借用torch自带的reshape对张量进行遍历,其实还有一个思路:那就是根据张量在所有元素里面的索引值,然后映射到一个多维的索引值。例如将7映射为(2,1)。具体实现我还没想好,等有空的时候更!

五、备注

        其实可以把张量理解为一颗多叉树,而不是人们平常所说的1维、2维、3维乃至N维数据,说成N维数据的,你会发现这一点都不好理解,当你把一个张量画成一颗树的时候,你会发现,原来张量是那么简单...

你可能感兴趣的:(Python相关,pytorch,python)