Tensor索引
dim 0 开始索引
a = torch.rand(4,3,28,28)
a[0].shape # 第0张图片 torch.Size([3,28,28])
a[0,0].shape #取第0张图片的第0个通道 torch.Size([28,28])
a[0,0,2,4].shape #取第0张图片的第0个通道的第2行第4列的像素点 tensor(0.8082)
取前/后几个图片
a[:2].shape #取前两张图片 torch.Size([2,3,28,28])
a[:2,:1,:,:].shape #取前两张图片第一个通道上的所有数据 torch.Size([2,1,28,28])
a[:2,1:,:,:].shape #取前两张图片后两个通道上的所有数据 torch.Size([2,2,28,28])
a[:2,-1:,:,:].shape #取前两张图片最后一个通道上的所有数据 torch.Size([2,1,28,28])
按间隔选取
a[:,:,0:28:2,0:28:2].shape #对所有图片所有通道的长和宽隔行采样 torch.Size([4,3,14,14])
a[:,:,::2,::2].shape #对所有图片所有通道的长和宽隔行采样 torch.Size([4,3,14,14])
选取具体索引值
a.index_select(0,torch.tensor([0,2])).shape # 选第0张和第2张图片的所有通道所有长宽torch.Size([2,3,28,28])
a.index_select((1,torch.tensor([1,2])).shape # 选G,B通道所有图片所有长宽 torch.Size([4,2,28,28])
a.index_select((2,torch.arange(8)).shape #所有图片所有通道8行所有宽 torch.Size([4,3,8,28])
...代表所有维度取满
a[...].shape # torch.Size([4,3,28,28])
a[0,...].shape # torch.Size([3,28,28])
a[:,1,...].shape # torch.Size([4,28,28])
a[...,:2].shape # torch.Size([4,3,28,2])
.masked_select() 会被打平
打平索引
b = torch.tensor([[4,3,5],
[6,7,8]])
torch.take(b,torch.tensor([0,2,5]))
# tensor([4,5,8]) 结果返回的是打平后的索引的值