I recently noticed that, besides plain indexing and slicing like NumPy arrays, PyTorch tensors also support a kind of fancy indexing: indexing a tensor with another tensor, which expands the indexed dimension into the shape of the index tensor. Below is an example, followed by an equivalent NumPy version.
in_a.shape = [b, c, h, w]
in_b.shape = [m, n]
Indexing in_a with in_b: out = in_a[:, :, in_b, :]
The resulting shape is out.shape = [b, c, m, n, w]
For example:
>>> in_a = torch.randn(1,1,4,5)
>>> in_b = torch.tensor([[2,0],[1,3],[2,3]])
>>> in_a
tensor([[[[ 0.2668,  0.5453,  0.5563,  0.7396, -1.1646],
          [-0.1059,  0.8955,  0.8947, -3.0298, -2.0912],
          [ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
          [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]]]])
>>> in_b.shape
torch.Size([3, 2])
>>> in_a[:,:,in_b,:].shape
torch.Size([1, 1, 3, 2, 5])
>>> in_a[:,:,in_b,:]
tensor([[[[[ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
           [ 0.2668,  0.5453,  0.5563,  0.7396, -1.1646]],

          [[-0.1059,  0.8955,  0.8947, -3.0298, -2.0912],
           [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]],

          [[ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
           [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]]]]])
That is, we index in_a along dim=2, taking the slices at index [2,0], [1,3], [2,3] in turn to fill the output. Note in particular that the index values must not exceed the size of dim=2: in this example in_a has shape [1,1,4,5], so when indexing along dim=2 the only valid index values are 0, 1, 2, 3.
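As a quick sanity check (my own sketch, not part of the original example), an out-of-range index value raises an IndexError:

import torch

in_a = torch.randn(1, 1, 4, 5)
bad_b = torch.tensor([[2, 0], [1, 4]])  # 4 is out of range: dim=2 has size 4, so valid indices are 0..3

try:
    out = in_a[:, :, bad_b, :]
except IndexError as e:
    # PyTorch rejects index values >= the size of the indexed dimension
    print("IndexError:", e)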
Another example:
>>> in_a[:,:,:,in_b].shape
torch.Size([1, 1, 4, 3, 2])
>>> in_a[:,:,:,in_b]
tensor([[[[[ 0.5563,  0.2668],
           [ 0.5453,  0.7396],
           [ 0.5563,  0.7396]],

          [[ 0.8947, -0.1059],
           [ 0.8955, -3.0298],
           [ 0.8947, -3.0298]],

          [[ 0.4827,  0.8145],
           [ 0.3670,  0.1327],
           [ 0.4827,  0.1327]],

          [[-0.8810,  1.3698],
           [-0.8281,  1.6670],
           [-0.8810,  1.6670]]]]])
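To make the semantics of this last-dimension case more concrete, the same result can also be built, if I understand it correctly, by flattening the index tensor, using torch.index_select, and reshaping. This is just an illustrative sketch, not the original code:

import torch

in_a = torch.randn(1, 1, 4, 5)
in_b = torch.tensor([[2, 0], [1, 3], [2, 3]])

# Select the 6 columns given by the flattened index, then restore the [3, 2] index shape
flat = torch.index_select(in_a, dim=3, index=in_b.reshape(-1))  # shape [1, 1, 4, 6]
out = flat.reshape(1, 1, 4, *in_b.shape)                        # shape [1, 1, 4, 3, 2]

print(torch.equal(out, in_a[:, :, :, in_b]))  # True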
For the NumPy version, the only approach I have come up with so far is the rather clumsy one of looping over the index array and copying slices one by one:
import numpy as np

num_a = in_a.numpy()  # shape [1, 1, 4, 5]
num_b = in_b.numpy()  # shape [3, 2]
[b, c, h, w] = num_a.shape
[m, n] = num_b.shape

out_ny = np.zeros([b, c, m, n, w])  # shape [1, 1, 3, 2, 5]
# Copy the slice selected by each index value into the output, one (i, j) at a time
for i in range(m):
    for j in range(n):
        out_ny[:, :, i, j, :] = num_a[:, :, num_b[i, j], :]
out_ny
array([[[[[ 0.81448293,  0.36703789,  0.48273084,  0.13274327,
            -0.94368148],
          [ 0.26677063,  0.54529017,  0.55633378,  0.73956281,
            -1.16463828]],

         [[-0.10586801,  0.89547068,  0.89467597, -3.02978396,
            -2.09123206],
          [ 1.36978781, -0.8280825 , -0.8810119 ,  1.6670413 ,
            -1.87361884]],

         [[ 0.81448293,  0.36703789,  0.48273084,  0.13274327,
            -0.94368148],
          [ 1.36978781, -0.8280825 , -0.8810119 ,  1.6670413 ,
            -1.87361884]]]]])
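Incidentally, NumPy's advanced indexing seems to support the same broadcast-style indexing directly, which would avoid the loop entirely. A minimal sketch reusing num_a, num_b, and out_ny from above:

import numpy as np

# Indexing dim 2 of num_a with the 2-D index array num_b should expand that
# dimension into num_b's shape, just like the PyTorch version above.
out_np = num_a[:, :, num_b, :]
print(out_np.shape)                 # (1, 1, 3, 2, 5)
print(np.allclose(out_np, out_ny))  # True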