关于pytorch用tensor索引另一个tensor问题

1. 问题

在项目中关于如下代码出现问题:

def fun1():
    start_transitions = torch.nn.Parameter(torch.empty(5))
    torch.nn.init.uniform_(start_transitions, -0.1, 0.1)
    tags_b = tags.byte()
    rs = start_transitions[tags_b[0]]
    print(rs)

出现的错误:

The shape of the mask [20] at index 0 does not match the shape of the indexed tensor [5] at index 0

问题:报错原因在哪里?

2. 关于tensor的索引问题

研究支持tensor的索引有哪些类型,运行如下:

def fun2():
    start_transitions = torch.nn.Parameter(torch.empty(5))
    torch.nn.init.uniform_(start_transitions, -0.1, 0.1)
    tags_i = tags.int()
    print(start_transitions[tags_i[0]])

出现如下的报错:

tensors used as indices must be long, byte or bool tensors

结论:tensors的下标必须为long或byte类型。
陷阱:long与type的作用又不一样。

3. 关于byte类型作为下标

 def fun_tyte_indx():
    start_transitions = torch.nn.Parameter(torch.empty(5))
    torch.nn.init.uniform_(start_transitions, -0.1, 0.1)
    tags_b = tags.byte()
    print(start_transitions[tags_b[0][:5]])
    print(start_transitions[tags_b[0]])

对于输出,第一个没有错误,已输出来,可是最后一行出现报错:

tensor([-0.0560, -0.0440,  0.0341, -0.0022, -0.0191], grad_fn=<IndexBackward>)
The shape of the mask [20] at index 0 does not match the shape of the indexed tensor [5] at index 0

结论:Byte类型的下标操作像是一个mask,将原有tensor进行筛选一遍,取出tensor2 对应位置不为0的元素;

4. 关于long类型作为下标

def fun_long_index():
    a = torch.arange(16, 30)
    print('a=', a)
    index_list = [[4, 1, 2], [2, 1, 1]]
    c = torch.LongTensor(index_list)
    print('c:', c)
    print('a[c]:', a[c])
    a = a.view(7, 2)
    print('a_7*2:', a)
    print('a[c]:', a[c])
    print(a.shape, c.shape, a[c].shape)

输出内容:

a= tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
c: tensor([[4, 1, 2],
        [2, 1, 1]])
a[c]: tensor([[20, 17, 18],
        [18, 17, 17]])
a_7*2: tensor([[16, 17],
        [18, 19],
        [20, 21],
        [22, 23],
        [24, 25],
        [26, 27],
        [28, 29]])
a[c]: tensor([[[24, 25],
         [18, 19],
         [20, 21]],

        [[20, 21],
         [18, 19],
         [18, 19]]])
torch.Size([7, 2]) torch.Size([2, 3]) torch.Size([2, 3, 2])

从例子来看,c相当于中将所有的元素替换成a中指定位置的元素;对于多维选择dim=0;
结论:相当于在 tensor2 中将所有的元素替换成tensor1中指定位置的元素;

5. 总结

同样的代码,同样的数值类型,就是由于保存的位数不一样,会产生了不一样的结果。这样对于一直以来使用高级语言数值类型会自动转思维定势要带来一些未知bug产生。需三思。

[happyprince] https://blog.csdn.net/ld326/article/details/105114212

你可能感兴趣的:(pytorch)