Pytorch与Numpy在数据Indexing的区别

Pytorch又被称为GPU版的Numpy,二者的许多功能都有良好的一一对应。
在这些一一对应中,Indexing是较为模糊的。
例如,我们常常使用Bool的List作为Index,取出Array中的某些行。

a = torch.rand(3, 3)
print(a)

 0.1041  0.6888  0.7988
 0.9398  0.9151  0.7642
 0.5340  0.4715  0.8128
[torch.FloatTensor of size (3,3)]

# 0-试图取出第0行
index = [True, False, False]

# 1-Numpy语义

aa = a.numpy()
print(aa[index])
array([[0.10411686, 0.6887991 , 0.7988465 ]], dtype=float32)

# 2-Pytorch语义
print(a[index])

 0.9398  0.9151  0.7642 #第1行
 0.1041  0.6888  0.7988 #第0行
 0.1041  0.6888  0.7988 #第0行
[torch.FloatTensor of size (3,3)]

# 3-与Numpy一致的Pytorch语义
print(a[torch.tensor(index)])

0.1041  0.6888  0.7988
[torch.FloatTensor of size (1,3)]

 

同时,Pytorch也支持其他Indexing的方式,例如torch.index_select()与torch.mask_select()


# 0-index_select的用法。

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])

# 1-选取第0行与第2行,注意:此处Index是[0,2],而非[True,False,True]。
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)

tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])

# 2-选取第0列与第2列。
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

# 3-mask_select的用法。

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
        [-1.2035,  1.2252,  0.5002,  0.6248],
        [ 0.1307, -2.0608,  0.1244,  2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[ 0,  0,  0,  0],
        [ 0,  1,  1,  1],
        [ 0,  0,  0,  1]], dtype=torch.uint8)
>>> torch.masked_select(x, mask)
tensor([ 1.2252,  0.5002,  0.6248,  2.0139])

 

你可能感兴趣的:(深度学习,python)