pytorch: tensors used as indices 用tensor索引另一个tensor

tensor1[tensor2]

刚看到这个结构有点懵,不知道它是具体怎么工作的

example.py


a = torch.arange(16)

b = torch.tensor([2,2,0,1,0,0,1,0,2,1,0,0,1,0,0,0],dtype=torch.uint8)
print(a)
print(b)
print(a[b])

index_list = [[4,3,2,1,0]]
c = torch.LongTensor(index_list)

# print(a)
print(a[c])
print(a.shape,c.shape,a[c].shape)

d = []
for i,index in enumerate(index_list):
    d.append(a[index])
print(d)

'''output
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
tensor([2, 2, 0, 1, 0, 0, 1, 0, 2, 1, 0, 0, 1, 0, 0, 0], dtype=torch.uint8)
tensor([ 0,  1,  3,  6,  8,  9, 12])
tensor([[4, 3, 2, 1, 0]])
torch.Size([16]) torch.Size([1, 5]) torch.Size([1, 5])
[tensor([4, 3, 2, 1, 0])]
'''
索引为torch.uint8类型

可以看到在tensor2bool/uint8类型时,tensor2 更像是一个mask,将原有tensor进行筛选一遍,取出tensor2 对应位置不为0的元素

索引为torch.long类型

这个时候就比较麻烦了,tensor2中存的更像是tensor1中的位置id, 这个时候a[b].shape == b.shape 相当于在 tensor2 中将所有的元素替换成tensor1中指定位置的元素,写了一个替代脚本:

a = torch.arange(16)

index_list = [[4,3,2,1,0]]
c = torch.LongTensor(index_list)

print(a[c])

d = []
for i,index in enumerate(index_list):
    d.append(a[index])
print(d)

# a[c] == d


## 多维的tensor

a = torch.arange(12).view(4,3)

print(a[c])
print(a.shape,c.shape,a[c].shape)

d = []

for i,index in enumerate(index_list):
    d.append(a[index])

print(d)

'''

tensor([[[ 6,  7,  8],
         [ 9, 10, 11],
         [ 6,  7,  8],
         [ 3,  4,  5],
         [ 0,  1,  2]]])
         
torch.Size([4, 3]) torch.Size([1, 5]) torch.Size([1, 5, 3])

[tensor([[ 6,  7,  8],
        [ 9, 10, 11],
        [ 6,  7,  8],
        [ 3,  4,  5],
        [ 0,  1,  2]])]

'''

你可能感兴趣的:(Pytorch_Mxnet)