pytorch学习笔记(五)——Tensor索引

Tensor索引 

dim 0 开始索引

a = torch.rand(4,3,28,28)
a[0].shape  # 第0张图片 torch.Size([3,28,28])

a[0,0].shape  #取第0张图片的第0个通道 torch.Size([28,28])

a[0,0,2,4].shape  #取第0张图片的第0个通道的第2行第4列的像素点 tensor(0.8082)

取前/后几个图片

a[:2].shape  #取前两张图片 torch.Size([2,3,28,28])

a[:2,:1,:,:].shape #取前两张图片第一个通道上的所有数据 torch.Size([2,1,28,28])

a[:2,1:,:,:].shape #取前两张图片后两个通道上的所有数据 torch.Size([2,2,28,28])

a[:2,-1:,:,:].shape #取前两张图片最后一个通道上的所有数据 torch.Size([2,1,28,28])

按间隔选取

a[:,:,0:28:2,0:28:2].shape  #对所有图片所有通道的长和宽隔行采样 torch.Size([4,3,14,14])

a[:,:,::2,::2].shape  #对所有图片所有通道的长和宽隔行采样 torch.Size([4,3,14,14])

选取具体索引值

a.index_select(0,torch.tensor([0,2])).shape # 选第0张和第2张图片的所有通道所有长宽torch.Size([2,3,28,28])

a.index_select((1,torch.tensor([1,2])).shape # 选G,B通道所有图片所有长宽 torch.Size([4,2,28,28])

a.index_select((2,torch.arange(8)).shape #所有图片所有通道8行所有宽 torch.Size([4,3,8,28])

...代表所有维度取满

a[...].shape  # torch.Size([4,3,28,28])

a[0,...].shape  # torch.Size([3,28,28])

a[:,1,...].shape  # torch.Size([4,28,28])

a[...,:2].shape  # torch.Size([4,3,28,2])

.masked_select()  会被打平

pytorch学习笔记(五)——Tensor索引_第1张图片取>0.5的元素

打平索引

b = torch.tensor([[4,3,5],
                  [6,7,8]])
torch.take(b,torch.tensor([0,2,5]))
# tensor([4,5,8]) 结果返回的是打平后的索引的值

你可能感兴趣的:(pytorch,python,数据分析,pytorch,深度学习)