我们经常需要从2维或3维tensor中进行切片操作,比如从mask模型中取出mask所在位置的向量。
Talk is cheap, show me code.
以下所有维度从0开始,3维即 0,1,2
import torch
x=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
print(x[[1,2],[0,2]]) # 第0维取 1,2即 [4,5,6],[7,8,9], 在取出的第0维中,分别取第0个和2个,即4,9, 输出 [4,9]
# tensor([4, 9])
# 当第一维不指定值时,表示第一维的每一个都按第二维取值,如
print(x[:,[0,2]]) #会输出第一维每行的第0,2个,即,即 3*2 =6个值
'''
tensor([[1, 3],
[4, 6],
[7, 9]])
'''
print(x[[1,2],[[0],[2]]]) # 第0维取 1,2即[4,5,6],[7,8,9],在这两个中,取两次,第一次取它们的第0个元素,第二次取它们的第2个元素,所以是[4,7],[6,9]
#tensor([[4, 7],
# [6, 9]])
y=torch.tensor([
[[1,2,3],[4,5,6],[7,8,9]],
[[11,12,13],[14,15,16],[17,18,19]],
[[21,22,23],[24,25,26],[27,28,29]]
])
print(y.shape)
print(y[[0,1,2],[0,1,1]]) # 只对第0,1维选择,第2维保持原样,输出
'''
tensor([[ 1, 2, 3],
[14, 15, 16],
[24, 25, 26]])
'''
-----------------------------------------------------
另外,pytorch的函数已经为 这种切片操作准备好了,用以下代码:
batch["loss_ids "] 是 mask的位掩 码,对应的是 如 [0,0,0,1,0,0,0] 这样的形式,值 为1的位置表示是mask。
outputs = outputs[torch.where(batch['loss_ids']>0)]
本质上的操作是 outputs[[1,2,3,4...],[3,3,3,3]]
torhc.where输出 的是二维tensor,第一位是行号,第二位是在对应行的mask位置,直接符合切片操作的要求,直接使用
其它矩阵操作参考: pytorch索引、切片、连接和换位_逝去〃年华的博客-CSDN博客