在pytorch的学习过程中,有时候我们需要对张量进行遍历,那么这是可以的。张量的遍历有许多方法,这里的话我提供两种思路:第一,在张量shape已知的情况下,我们可以通过设置多层for循环来进行遍历;第二,如果不知道shape怎么办呢?此时可以通过张量的reshape操作,然后再用一层for循环来进行遍历。
对于一个张量,如果我们知道它的shape,也就是说,shape为几个dim,那么我们就能够很好地对张量进行遍历。例如dim的个数为2,那么这个张量的shape就是(a, b)。这种情况下,我们可以设置两层for循环来对张量进行遍历。具体方式如下:
import torch
# 定义一个张量
a = torch.tensor([[0.0, 0.1, 0.1], [1.0, 1.5, 1.4], [2.1, 2.3, 2.4], [3.9, 3.9, 3.8]])
# 获取张量的shape
dim0, dim1 = a.shape
# 遍历张量
for i in range(dim0):
for j in range(dim1):
element = a[i][j]
print('%.2f' % element.item(), end=' ')
输出
0.00 0.10 0.10 1.00 1.50 1.40 2.10 2.30 2.40 3.90 3.90 3.80
有小伙伴可能感到疑惑,那如果我都不知道这个张量的形状是多少个dim的呢?那就不可以遍历了?非也,还是能够遍历的,只不过自由度没有那么大了在这种情况下,因为我们不知道要设置多少层for循环。我的想法是把这个张量reshpe一下,把它变成一维的,让遍历交给pytorch处理。最后,我们再通过一层for循环输出一下即可。具体的方式如下:
import torch
# 定义一个张量
a = torch.tensor([[0.0, 0.1, 0.1], [1.0, 1.5, 1.4], [2.1, 2.3, 2.4], [3.9, 3.9, 3.8]])
# 对张量进行reshape,转成一维张量
b = a.reshape(-1)
# 遍历一维张量
for i in b:
element = i.item()
print('%.2f' % element, end=' ')
输出
0.00 0.10 0.10 1.00 1.50 1.40 2.10 2.30 2.40 3.90 3.90 3.80
经过对比,这两种遍历方式居然一模一样。按理应该还有其他
除了借用torch自带的reshape对张量进行遍历,其实还有一个思路:那就是根据张量在所有元素里面的索引值,然后映射到一个多维的索引值。例如将7映射为(2,1)。具体实现我还没想好,等有空的时候更!
其实可以把张量理解为一颗多叉树,而不是人们平常所说的1维、2维、3维乃至N维数据,说成N维数据的,你会发现这一点都不好理解,当你把一个张量画成一颗树的时候,你会发现,原来张量是那么简单...