indexing
import torch
a = torch.rand(4,3,28,28) # 表示4张28*28的rgb图
print(a[0].shape) # a[0]获得第一张图片
print(a[0,0].shape) # a[0,0]获得第一张图片的r图
print(a[0,0,2,4]) # 获得第一张图片第一个通道的一个像素点,因此得到的是一个标量
select first/last N
# select first/last N
print(a[:2].shape) # :2 => 0,1
print(a[:2,:1,:,:].shape) # :1 => 0
print(a[:2,1:,:,:].shape) # 1: => 1,2
print(a[:2,-1:,:,:].shape) # -1: => 2
select by steps
# select by steps
print(a[:,:,0:28:2,0:28:2].shape) # 0:28:2 => 从0-28,步长为2
print(a[:,:,::2,::2].shape) # ::2 => 从0-28,步长为2
# 总结
# 1. : => all
# 2. :n => 从最开始到n,不包括n
# 3. n: => 从n到最后
# 4. start:end => 从start到end,不包含end
# 5. start:end:steps => 从start到end,不包含end,步长为2
select by specific index
# select by specific index
print(a.index_select(0,torch.tensor([0,2])).shape) # index_select() 第一个参数是维度,第二个参数是具体的索引号,但是索引号必须是tensor,所以要使用torch.tensor()
print(a.index_select(2,torch.arange(28)).shape) # torch.arange(28) => tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
# 18, 19, 20, 21, 22, 23, 24, 25, 26, 27])
...
# ...
print(a[...].shape) # a[...] => a[:,:,:,:]
print(a[0,...].shape) # a[0,...] => a[0,:,:,:]
print(a[:,1,...].shape) # a[:,1,...] => a[:,1,:,:]
print(a[...,:2].shape) # a[...,:2] => a[:,:,:,:2]
select by mask
# select by mask
# .masked_select() 会将数据默认打平->之所以打平是因为当满足某条件的位数是根据内容才能确定的
x = torch.randn(3,4)
mask = x.ge(0.5) # 将大于等于0.5的数取为ture
print(mask) # 掩码
print(torch.masked_select(x,mask)) # 根据掩码取数据和原shape无关
print(torch.masked_select(x,mask).shape)
select by flatten index
# select by flatten index 打平
src = torch.tensor([[4,3,5],[6,7,8]])
print(torch.take(src,torch.tensor([0,2,5]))) # 打平,取索引为0,2,5的数