torch.index_select(input, dim, index, *, out=None) → Tensor
功能:选择根据给定的index和dim在input中选择张量数据,相当于更高级的索引功能。
参数:
注意:返回的张量数组与原始的张量数组 具有相同的维数,这里与直接进行索引有区别
import torch
a=torch.arange(40).view(2,4,5)
index=torch.tensor([1,3])
select_1=torch.index_select(a,dim=1,index=index)
print(select_1)
# tensor([[[ 5, 6, 7, 8, 9],
# [15, 16, 17, 18, 19]],
#
# [[25, 26, 27, 28, 29],
# [35, 36, 37, 38, 39]]])
select_2=torch.index_select(a,dim=2,index=index)
print(select_2)
# tensor([[[ 1, 3],
# [ 6, 8],
# [11, 13],
# [16, 18]],
#
# [[21, 23],
# [26, 28],
# [31, 33],
# [36, 38]]])
import torch
torch.manual_seed(100)
x = torch.randn(2, 3)
print(x)
# 用 bool 值表示,是否大于0
mask=x>0
print('\n', mask)
# 获取大于0的值
torch.masked_select(x, mask)
# 获取非0下标索引, True为1,False为0
torch.nonzero(mask)
print('\n', torch.nonzero(mask))
import torch
torch.manual_seed(100)
x = torch.randn(2, 3)
print(x)
index = torch.LongTensor([[0, 1, 1]])
# 按照 index 从 x 中取值,输出与 index 的形状相同
torch.gather(x, 0, index)
import torch
torch.manual_seed(100)
x = torch.randn(2, 3)
print(x)
index = torch.LongTensor([[0, 1, 1],
[1, 1, 1]])
# 按照 index 从 x 中取值,输出与 index 的形状相同
torch.gather(x, 1, index)
条件:a 的形状必需和 index 保持一致,否则会报错
作用:dest_tensor.scatter(dim, index, src_tensor)
src_tensor 和 dest_tensor 是两个张量
a = torch.arange(6).view(2, 3).to(torch.float32)
print('a: \n', a)
index = torch.tensor([[0, 2, 2],
[2, 1, 0]])
print('index: \n', index, '\n')
b = torch.zeros((3, 3))
print('b 在执行 scatter 之前\n', b, '\n')
b.scatter_(0, index, a)
print('b 在执行 scatter 之后 (dim =0)\n', b)