pytorch中获取指定位置元素

  这段代码的应用场景是:某个batch的sentence,有的经过了padding操作,如果获取每句话中实际的最后一个单词。

A = torch.Tensor([[[2, 3, 1], [1, 4, 0], [1, 0, 0]], [[2, 2, 0], [2, 0, 0], [3, 1, 4]]])
print(A.size())

B = torch.Tensor([[3, 2, 1], [2, 1, 3]]).long()
print(B.size())
B = B.view(2, 3, -1)
B = B - 1

C = torch.gather(A, 2, B)
print(C)

你可能感兴趣的:(Python,NLP)