I ran into the need to index one tensor with another tensor. PyTorch has two fairly similar operations for this: torch.index_select and torch.gather.
torch.index_select(input, dim, index, out=None) → Tensor
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index, which is a LongTensor.
The returned tensor has the same number of dimensions as the original tensor (input). The dim-th dimension has the same size as the length of index; the other dimensions have the same size as in the original tensor. (In the output tensor, the specified dim has the length of index; the other dimensions are unchanged.)
NOTE: The returned tensor does not use the same storage as the original tensor. If out has a different shape than expected, we silently change it to the correct shape, reallocating the underlying storage if necessary.
Example:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
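A quick sketch of how index_select relates to ordinary advanced indexing (an equivalence I am asserting here, not taken from the docs above): selecting along dim 0 is the same as indexing rows with the index tensor, and selecting along dim 1 is the same as indexing columns.

```python
import torch

x = torch.randn(3, 4)
indices = torch.tensor([0, 2])

# index_select along dim 0 picks rows 0 and 2 -> shape (2, 4)
rows = torch.index_select(x, 0, indices)
assert torch.equal(rows, x[indices])

# index_select along dim 1 picks columns 0 and 2 -> shape (3, 2)
cols = torch.index_select(x, 1, indices)
assert torch.equal(cols, x[:, indices])
```

This matches the shape rule above: only the selected dim changes size (to len(index)); the other dimensions stay the same.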
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
If input is an n-dimensional tensor with size (x_0, x_1, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1}) and dim = i, then index must be an n-dimensional tensor with size (x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1}) where y ≥ 1, and out will have the same size as index. (The output tensor has the same shape as index; index may differ from input only along the dim dimension.)
Example:
>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1, 1],
[ 4, 3]])
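To complement the dim=1 example above, here is a small sketch of gather along dim=0, plus the common pattern of picking one element per row (values traced by hand from the dim=0 formula above):

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])

# dim=0: out[i][j] = t[index[i][j]][j] -- index chooses the ROW at each position
out0 = torch.gather(t, 0, torch.tensor([[0, 1], [1, 0]]))
# out0 is tensor([[1, 4], [3, 2]])

# Common use: pick one element per row. index only needs shape (2, 1),
# matching t in every dimension except dim=1 (where any size y >= 1 is allowed).
per_row = torch.gather(t, 1, torch.tensor([[1], [0]]))
# per_row is tensor([[2], [3]])
```

Note how per_row has the same shape as its index, (2, 1), illustrating the size rule stated above.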
Both operations above index with a single tensor, changing one dimension while keeping the others fixed. There also seems to be another way of indexing jointly with multiple tensors, for example:
roi_cls_loc1[t.arange(0, n_sample1).long().cuda(), at.totensor(gt_roi_label1).long()]
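A self-contained sketch of what that line does, with hypothetical stand-in names (roi_cls_loc, labels) and shapes modeled on the snippet: indexing with two tensors pairs them elementwise, so for each sample i it picks the slice at class labels[i]. The same result can be expressed with gather.

```python
import torch

# Hypothetical stand-ins for roi_cls_loc1 / gt_roi_label1 above:
# per-sample, per-class regression outputs, shape (n_sample, n_class, 4)
n_sample, n_class = 5, 3
roi_cls_loc = torch.randn(n_sample, n_class, 4)
labels = torch.randint(0, n_class, (n_sample,))  # one class label per sample

# Advanced indexing with two tensors pairs them elementwise:
# picked[i] = roi_cls_loc[i, labels[i]]  -> shape (n_sample, 4)
picked = roi_cls_loc[torch.arange(n_sample), labels]

# Equivalent gather formulation: expand labels to index dim=1 (the class dim)
idx = labels.view(-1, 1, 1).expand(-1, 1, 4)      # shape (n_sample, 1, 4)
same = roi_cls_loc.gather(1, idx).squeeze(1)      # shape (n_sample, 4)
assert torch.equal(picked, same)
```

Unlike index_select/gather, this multi-tensor indexing collapses the paired dimensions into one, which is why the result is (n_sample, 4) rather than (n_sample, n_sample, 4).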