PyTorch 索引与切片

indexing

#从第0维往后排
a = torch.rand(4,3,28,28)
print(a[0].shape)
print(a[0,0].shape)
print(a[0,0,0].shape)

print(a[0,0,0,0])

从前或者后面全取 

#从第0维往后排
a = torch.rand(4,3,28,28)
#取最前面的
print(a[:2].shape)
print(a[:2,:1,:,:].shape)
#取最后面的
print(a[2:,1:,:,:].shape)
print(a[2:,-1:,:,:].shape)#-1表示倒数第一
print(a[:,:,::2,::2].shape)

1、:单独出现表示取全部

2、:n表示,从0到n

3、n:表示从n到最后

4、n:m,表示从n到m,不包括m

4、n:m:k,表示从n到m,不包括m,隔行采样,间隔k取一个

特殊的选择某区间

#从第0维往后排,第二个参数必须是tensor
a = torch.rand(4,3,28,28)
print(a.index_select(0,torch.tensor([0,2])).shape)
print(a.index_select(1,torch.tensor([0,2])).shape)

使用...

#从第0维往后排  ...表示剩余的任意长
a = torch.rand(4,3,28,28)
print(a[...].shape)
print(a[0,...].shape)
print(a[...,:2].shape)

select by mask,不建议使用,会把数据默认打平

x = torch.randn(3,4)
print(x)

mask = x.ge(0.5)#大于0.5处为true
print(mask)

print(torch.masked_select(x,mask))
print(torch.masked_select(x,mask).shape)

select by flatten index

也会进行打平,比如查找a[2][3]中最后一个用下标5

x = torch.tensor([ [4,3,5],[6,7,8] ])
print(torch.take(x,torch.tensor([0,2,5])))

 

你可能感兴趣的:(pytorch,机器学习)