文章目录
-
- pytorch-张量的索引与切片
-
- 直接索引-indexing
- 高级索引-连续选取-select first/last N
- 高级索引-间隔选取-select by steps
- 高级索引-选取具体索引号-select by specific index
- 高级索引-掩码选取-select by mask
pytorch-张量的索引与切片
直接索引-indexing
>>>a = torch.rand(4, 3, 28, 28)
>>>a.shape
torch.Size([4, 3, 28, 28])
>>>a[0].shape
torch.Size([3, 28, 28])
>>>a[0, 0].shape
torch.Size([28, 28])
高级索引-连续选取-select first/last N
>>>a.shape
torch.Size([4, 3, 28, 28])
>>>a[:2].shape
torch.Size([2, 3, 28, 28])
>>>a[:2, :1, :, :].shape
torch.Size([2, 1, 28, 28])
>>>a[:2, :1].shape
torch.Size([2, 1, 28, 28])
>>>a[:2, 1:].shape
torch.Size([2, 2, 28, 28])
>>>a[:2, -1:].shape
torch.Size([2, 1, 28, 28])
高级索引-间隔选取-select by steps
语法 |
含义 |
: |
all |
:n |
[0, n) |
n : |
[n, END) |
n1 :n2 |
[n1, n2) |
n1 :n2 :step |
从n1开始,每经过step个数取一个;step=1时可以省去 |
>>>a[:, :, 0:28:2, 0:28:2].shape
torch.Size([4, 3, 14, 14])
高级索引-选取具体索引号-select by specific index
>>>a.index_select(0, torch.tensor([0, 2])).shape
torch.Size([2, 3, 28, 28])
>>>a.index_select(1, torch.tensor([1, 2])).shape
torch.Size([4, 2, 28, 28])
>>>a[...].shape
torch.Size([4, 3, 28, 28])
>>>a[0, ...].shape
torch.Size([3, 28, 28])
>>>a[:, 1, ...].shape
torch.Size([4, 28, 28])
>>>a[..., :2].shape
torch.Size([4, 3, 28, 2])
高级索引-掩码选取-select by mask
>>>x = torch.randn(3, 4)
>>>x
tensor([[ 0.4295, 0.4206, -0.5532, -1.4849],
[ 0.4281, 0.0803, -1.1630, 1.4777],
[ 0.1567, -0.9116, -1.7134, -0.1576]])
>>>mask = x.ge(0)
>>>mask
tensor([[ True, True, False, False],
[ True, True, False, True],
[ True, False, False, False]])
>>>torch.masked_select(x, mask)
tensor([0.4295, 0.4206, 0.4281, 0.0803, 1.4777, 0.1567])