Pytorch又被称为GPU版的Numpy,二者的许多功能都有良好的一一对应。 在这些一一对应中,Indexing是较为模糊的。 例如,我们常常使用Bool的List作为Index,取出Array中的某些行。
a = torch.rand(3, 3)
print(a)
0.1041 0.6888 0.7988
0.9398 0.9151 0.7642
0.5340 0.4715 0.8128
[torch.FloatTensor of size (3,3)]
# 0-试图取出第0行
index = [True, False, False]
# 1-Numpy语义
aa = a.numpy()
print(aa[index])
array([[0.10411686, 0.6887991 , 0.7988465 ]], dtype=float32)
# 2-Pytorch语义
print(a[index])
0.9398 0.9151 0.7642 #第1行
0.1041 0.6888 0.7988 #第0行
0.1041 0.6888 0.7988 #第0行
[torch.FloatTensor of size (3,3)]
# 3-与Numpy一致的Pytorch语义
print(a[torch.tensor(index)])
0.1041 0.6888 0.7988
[torch.FloatTensor of size (1,3)]
同时,Pytorch也支持其他Indexing的方式,例如torch.index_select()与torch.mask_select()
# 0-index_select的用法。
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
# 1-选取第0行与第2行,注意:此处Index是[0,2],而非[True,False,True]。
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
# 2-选取第0列与第2列。
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
# 3-mask_select的用法。
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],
[-1.2035, 1.2252, 0.5002, 0.6248],
[ 0.1307, -2.0608, 0.1244, 2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[ 0, 0, 0, 0],
[ 0, 1, 1, 1],
[ 0, 0, 0, 1]], dtype=torch.uint8)
>>> torch.masked_select(x, mask)
tensor([ 1.2252, 0.5002, 0.6248, 2.0139])