torch.gather

https://pytorch.org/docs/stable/generated/torch.gather.html

一个简单的例子:

t = torch.rand(2,3)
"""
tensor([[0.8133, 0.5586, 0.7917],
        [0.0551, 0.2322, 0.9087]])
"""
t.gather(dim=0,index=torch.tensor([[0,1,0],[1,0,1]]))
"""
tensor([[0.8133, 0.2322, 0.7917],
        [0.0551, 0.5586, 0.9087]])
"""
  • dim = 0,说明index中所有索引均是索引行。
  • 关于index的shape:dim=1时,input.shape[0]=index.shape[0], 同理, 可推dim=0时,input.shape[1]=index.shape[1]
# 常用于以下需求:
# celoss = torch.tensor([i_s[i_t] for i_s,i_t in zip(softmax,target)])

input = torch.randn(3, 5, requires_grad=True) # (3,5)

n_samples = input.shape[0] # 注意dim=1时,input.shape[0]=index.shape[0], 同理, 可推dim=0时,input.shape[1]=index.shape[1]
channel = 6

idx = torch.randint(low=0,high=5,size=(n_samples*channel,)).reshape(n_samples,channel)
"""
tensor([[0, 0, 4, 2, 3, 1], 第一行取第0个,第0个,第4个...
        [3, 3, 1, 0, 2, 2], 第二行取第3个,第3个,第1个...
        [4, 4, 4, 2, 1, 3]]) ...
"""
input.gather(dim=1,index=idx) # torch.Size([3, 6])

你可能感兴趣的:(torch.gather)