目录
索引表达式
一般索引
普通的切片索引
使用step的普通的切片索引
通过index_select接口在特定轴上的index获得元素
任意多的维度
mask索引
take索引
我们可以通过Pytorch Tensor的索引和切片获得tensor一部分数据得到一个新的tensor。通常我们可以通过索引表达式,tensor.index_select接口,任意多的维度表达式,mask索引,take索引从原有的tensor中得到相应的子tensor。
在PyTorch 张量(tensor)的秩,轴,形状(Rank, Axes, and Shape)的理解一文中我们知道每个张量都有若干个轴组成,我们可以通过索引,获取轴上的元素。
一般索引值表达形式是[index, index,...]
举例说明:
假设我们定义这样一个tensor,torch.rand(4, 3, 28, 28) 这个tensor表示minis数据集上的一个batch数据,包含了4张图片,每张图片3个通道(RGB),每张图片的大小是28(行)*28(列)
import torch
a = torch.rand(4, 3, 28, 28)
print(a[0].shape) #取第一张图片的Shape
print(a[0, 0].shape) #取第一张图片的第一个通道(R通道)的Shape
print(a[0, 0, 2, 4]) #取第一张图片,R通道,第三行,第五列的像素值,是一个标量
运行结果
torch.Size([3, 28, 28])
torch.Size([28, 28])
tensor(0.3690)
切片表达式:[start:end:step]
表达式的意义是,从start开始到end结束,每隔step个进行采样。根据start,end,step以及:选项可以分为:
举例说明:
print(a[:2].shape) # 在第一个维度上取后0和1
print(a[:2, :1, :, :].shape) # 在第一个维度上取0和1,在第二个维度上取0
print(a[:2, 1:, :, :].shape) # 在第一个维度上取0和1,在第二个维度上取1,2
print(a[:2, -2:, :, :].shape) # 在第一个维度上取0和1,在第二个维度上取1,2
运行结果:
torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 2, 28, 28])
torch.Size([2, 2, 28, 28])
举例说明:
print(a[:, :, 0:28:2, 0:28:2].shape) # step=2隔行采样
print(a[:, :, ::2, ::2].shape) # 等同于这个
运行结果:
torch.Size([4, 3, 14, 14])
torch.Size([4, 3, 14, 14])
index_select
接口在特定轴上的index获得元素torch.
index_select
(input, dim, index, *, out=None) → Tensor
举例说明:
选择特定下标有时候很有用,比如上面的a这个Tensor可以看作4张RGB(3通道)的MNIST图像,长宽都是28px。那么在第一维度上可以选择特定的图片,在第二维度上选择特定的通道。
# 选择第一张和第三张图
print(a.index_select(0, torch.tensor([0, 2])).shape)
# 选择R通道和B通道
print(a.index_select(1, torch.tensor([0, 2])).shape)
运行结果:
torch.Size([2, 3, 28, 28])
torch.Size([4, 2, 28, 28])
在索引中使用...
可以表示任意多的维度。
举例说明:
print(a[:, 1, ...].shape)
print(a[..., :2].shape)
print(a[0, ..., ::2].shape)
运行结果:
torch.Size([2, 1, 28, 28])
torch.Size([2, 3, 28, 2])
torch.Size([2, 3, 28, 14])
可以获取满足一些条件的值的位置索引,然后用这个索引去取出这些位置的元素。
举例说明:
import torch
# 取出a这个Tensor中大于0.5的元素
a = torch.randn(3, 4)
print(a)
x = a.ge(0.5)
print(x)
print(a[x])
运行结果:
tensor([[ 0.1638, 0.9582, -0.2464, -0.8064],
[ 1.8385, -2.0180, 0.8382, 1.0563],
[ 0.1587, -1.6653, -0.2057, 0.1316]])
tensor([[0, 1, 0, 0],
[1, 0, 1, 1],
[0, 0, 0, 0]], dtype=torch.uint8)
tensor([0.9582, 1.8385, 0.8382, 1.0563])
take索引是基于目标Tensor的flatten形式下的,即摊平后的Tensor的索引。
举例说明:
import torch
a = torch.tensor([[3, 7, 2], [2, 8, 3]])
print(a)
print(torch.take(a, torch.tensor([0, 1, 5])))
运行结果:
tensor([[3, 7, 2],
[2, 8, 3]])
tensor([3, 7, 3])