pytorch中的矩阵切片操作完全讲解

我们经常需要从2维或3维tensor中进行切片操作,比如从mask模型中取出mask所在位置的向量。

Talk is cheap, show me code. 

以下所有维度从0开始,3维即 0,1,2

import torch 


x=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])

print(x[[1,2],[0,2]])  # 第0维取  1,2即 [4,5,6],[7,8,9], 在取出的第0维中,分别取第0个和2个,即4,9, 输出  [4,9]

# tensor([4, 9])

# 当第一维不指定值时,表示第一维的每一个都按第二维取值,如

print(x[:,[0,2]])   #会输出第一维每行的第0,2个,即,即 3*2 =6个值 

'''
tensor([[1, 3],
        [4, 6],
        [7, 9]])
'''


print(x[[1,2],[[0],[2]]]) # 第0维取 1,2即[4,5,6],[7,8,9],在这两个中,取两次,第一次取它们的第0个元素,第二次取它们的第2个元素,所以是[4,7],[6,9]

#tensor([[4, 7],
#       [6, 9]])



y=torch.tensor([
                [[1,2,3],[4,5,6],[7,8,9]],
                [[11,12,13],[14,15,16],[17,18,19]],
                [[21,22,23],[24,25,26],[27,28,29]]
                
                ])

print(y.shape)

print(y[[0,1,2],[0,1,1]])  # 只对第0,1维选择,第2维保持原样,输出 


'''

tensor([[ 1,  2,  3],
        [14, 15, 16],
        [24, 25, 26]])

'''

pytorch中的矩阵切片操作完全讲解_第1张图片

-----------------------------------------------------

另外,pytorch的函数已经为 这种切片操作准备好了,用以下代码:

batch["loss_ids "] 是 mask的位掩 码,对应的是 如 [0,0,0,1,0,0,0] 这样的形式,值 为1的位置表示是mask。

 outputs = outputs[torch.where(batch['loss_ids']>0)]

本质上的操作是 outputs[[1,2,3,4...],[3,3,3,3]]

torhc.where输出 的是二维tensor,第一位是行号,第二位是在对应行的mask位置,直接符合切片操作的要求,直接使用

其它矩阵操作参考: pytorch索引、切片、连接和换位_逝去〃年华的博客-CSDN博客

你可能感兴趣的:(pytorch,pytorch,矩阵,python)