In [130]:a=torch.rand(4,3,28,28)
In [131]: a[0].shape
0ut[131]: torch.Size([3, 28, 28])
In [138]: a[0,0].shape
0ut[138]: torch.Size([28, 28])
In [139]: a[0,0,2,4]
0ut[139]: tensor(0. 8082)
In [140]: a.shape
0ut[140]: torch.Size([4, 3, 28, 28])
In [141]: a[:2].shape//第0个维度从0到1
0ut[141]: torch.Size([2, 3, 28, 28])
In [142]: a[:2,:1,:,:].shape
0ut[142]: torch.Size([2,1, 28, 28])
In [143]: a[:2,1:,:,:].shape
Out[143]: torch.Size([2, 2,28, 28])
In [144]: a[:2,-1:,:,:].shape//从最后一个元素到末尾
0ut[144]: torch.Size([2, 1, 28, 28])
In [145]: a[:,:,0:28:2,0:28:2].shape//从0到28间隔2选择
Out[145]: torch.Size([4, 3, 14, 14])
In [146]: a[:,:,::2,::2].shape//从头到尾间隔step步长选择
Out[146]: torch.Size([4, 3,14, 14])
//通用形式:start:end:step
index_select()
In [149]: a.shape
0ut[149]: torch.Size([4, 3, 28, 28])
In [159]: a.index_select(0, torch.sensor([0, 2])).shape
0ut[159]: torch.Size([2, 3, 28, 28])
In [159]: a.index_select(1, torch.sensor([1, 2])).shape
0ut[159]: torch.Size([4, 2, 28, 28])
In [168]: a.index_select(2, torch.arange(8)).shape
0ut[168]: torch.Size([4, 3, 8, 28])
//.index_select第二个参数不能以list的形式直接输入,需要以tensor的形式输入
…
In [149]: a.shape
Out[149]: torch.Size([4, 3, 28, 28])
In [150]: a[...].shape
0ut[150]: torch.Size([4, 3, 28, 28])
//a[...] 相当于 a[0] 相当于 a[:,:,:,:]
In [151]: a[0,...].shape
0ut[151]: torch.Size([3, 28, 28])
In [152]: a[:,1,...].shape
0ut[152]: torch.Size([4, 28, 28])
In [155]: a[..., :2].shape
0ut[155]: torch.Size([4, 3, 28, 2])
masked_select()
In [170]: x = torch.randn(3, 4)
tensor([[-1.3911, -0.7871, -1.6558, -0.2542],
[-0.9011, 0.5404, -0.6612, 0.3917],
[-0.3854, 0.2968, 0.6040, 1.5771]])
In [172]: mask = x.ge(0.5)
tensor([[0, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 1]], dtype=torch.uint8)
In [174]: torch.masked_select(x, mask)
0ut[174]: tensor([0.5404, 0.6040, 1.5771])
//masked_select(x, mask)按照mask的掩码选择x中对应索引的元素
In [175]: torch.masked_select(x, mask).shape
0ut[175]. torch.Stze([3])