torch会自动从左向右索引
例子:
a = torch.randn(4,3,28,28)
表示类似一个CNN 的图片的输入数据,4表示这个batch一共有4张照片,而3表示图片的通道数为3(RGB),(28,28)表示图片的大小
索引1:表示第零张图片的shape
print(a[0].shape)
#torch.Size([3,28,28])
索引2:第零张图片的第零个通道的size
print(a[0,0].shape)
#torch.Size([28,28])
索引3:表示第零张图片的第零个通道的第二行第四列的像素点的值
print(a[0,0,2,4])
#tensor(0.8082)
索引4:连续取两张图片(取第0张以及第一张图片,不包括第二张)
print(a[:2].shape)
#torch.Size([2,3,28,28])
#由于是两张图片,所以第一维变为2
索引5:前两张图片上的第一个通道上的数据(所以通道数变为了1)
print(a[:2,:1,:,:].shape) print(a[:2,:1].shape)
#torch.Size(2,1,28,28)
索引6:从后面取(-1表示最后一个,从最后一个取到最后,也就是一个通道)
print(a[:2,-1:,:,:].shape)
#torch.Size(2,1,28,28)
索引7:在图片的矩阵进行隔行与隔列索引 0:28:2表示从0到28(不包括28),间隔数为2
print(a[:, :, 0:28:2, 0:28:2].shape)
print(a[:, :, :: 2, :: 2].shape)
#torch.Size([4,3,14,14])
start : end : step
:都取
x:从x取到最后 :x 从开始取到x x:y从x取到y
x:y:z从x到y每隔z个点采样一次
使用index_select()函数
第一个参数表示你对哪个维度进行操作;第二个参数是index(必须是tensor类型):对第0张与第2张图片进行操作
a.index_select(0,torch.tensor([0,2])).shape
#【2,3,28,28】
同理:选择了两个通道
a.index_select(1,torch.tensor([1,2])).shape
#【4,2,28,28】
同理:只取8行
a.index_select(2,torch.arange(8)).shape
#【4,2,8,28】
使用符号:
…
例子:
a[…].shape
#[4,3,28,28]
a[0,…].shape
#[3,28,28]
a[0,1,…].shape
#[4,28,28]
a[…,2].shape
#[4,3,28,2]
函数: .masked_select() 会将筛选出来的元素打平(因为无法维护原来的shape)
x = torch.randn(2,3)
print(x)
tensor([[-1.3081, -0.5651, -0.9843],
[ 1.0051, -0.3829, 0.6300]])
mask = x.ge(0.5)#大于等于0.5的元素
print(mask)
tensor([[False, False, False],
[ True, False, True]])
z = torch.masked_select(x,mask)
print(z)
tensor([1.0051, 0.6300])
例子:使用take函数:是将输入的tensor打平之后进行index的选择
src = torch.tensor([[4,3,5],[6,7,8]])
torch.take(src,torch.tensor([0,2,8]))
#tensor([4,5,8])