PyTorch 索引与切片
Indexing
import torch

# Random 4-D tensor of shape (4, 3, 28, 28). It can be read as
# (batch, channel, height, width), but nothing below depends on that reading.
a = torch.rand(4, 3, 28, 28)

# A single integer index removes the leading dimension.
print("a[0].shape:\t", a[0].shape)
print("a[1].shape:\t", a[1].shape)

# Two integer indices remove the first two dimensions.
b = a[0, 0].shape
print('a[0, 0].shape:\t', b)
# Index a[1, 1] so the printed label matches the expression being shown.
print("a[1, 1].shape:\t", a[1, 1].shape)

# Indexing every dimension yields a 0-dim tensor (a single scalar value).
c = a[0, 0, 2, 4]
print("c:\t", c)
a[0].shape: torch.Size([3, 28, 28])
a[1].shape: torch.Size([3, 28, 28])
a[0, 0].shape: torch.Size([28, 28])
a[1, 1].shape: torch.Size([28, 28])
c: tensor(0.6412)
select first / last N
import torch

a = torch.rand(4, 3, 28, 28)

# Each pair is (label to print, tensor whose shape is printed).
# start:end slices keep positions start..end-1; an omitted bound means
# "from the beginning" / "to the end", and a negative start counts from the end.
_slice_examples = (
    ("a.shape:\t", a),
    ("a[:2].shape\t", a[:2]),
    ("a[:2,:1,:,:].shape:\t", a[:2, :1, :, :]),
    ("a[:2,1:,:,:].shape:\t", a[:2, 1:, :, :]),
    ("a[:2,-1:,:,:].shape:\t", a[:2, -1:, :, :]),
)
for _label, _view in _slice_examples:
    print(_label, _view.shape)
a.shape: torch.Size([4, 3, 28, 28])
a[:2].shape torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape: torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape: torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape: torch.Size([2, 1, 28, 28])
select by steps: start:end:step
import torch

a = torch.rand(4, 3, 28, 28)

# Slice syntax is start:end:step. Taking every 2nd position of both trailing
# 28-long dimensions halves each of them to 14.
# The label now spells out the same expression that is evaluated.
print("a[:,:,0:28:2,0:28:2].shape:\t", a[:, :, 0:28:2, 0:28:2].shape)
# Omitting start and end covers the whole dimension, so ::2 gives the same shape.
print("a[:,:,::2,::2].shape:\t", a[:, :, ::2, ::2].shape)
a[:,:,0:28:2,0:28:2].shape: torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape: torch.Size([4, 3, 14, 14])
select by specific index:
import torch

a = torch.rand(4, 3, 28, 28)
# index_select(dim, index) keeps only the listed positions along dimension
# `dim`; `index` is a 1-D tensor of integer indices.
b = a.index_select(0, torch.tensor([0, 2])).shape
print("b:\t", b)
c = a.index_select(1, torch.tensor([1, 2])).shape
print("c:\t", c)
# Selecting every index 0..27 along dim 2 leaves the shape unchanged.
d = a.index_select(2, torch.arange(28)).shape
print("d:\t", d)
# Selecting only indices 0..7 along dim 2 shrinks that dimension to 8.
e = a.index_select(2, torch.arange(8)).shape
print("e:\t", e)
b: torch.Size([2, 3, 28, 28])
c: torch.Size([4, 2, 28, 28])
d: torch.Size([4, 3, 28, 28])
e: torch.Size([4, 3, 8, 28])
“...”（Ellipsis）：代表任意多个维度（相当于补足所需数量的 “:”）
import torch

a = torch.rand(4, 3, 28, 28)

# `...` expands to as many full slices (`:`) as needed to cover the
# dimensions that are not written out explicitly.
b, c, d, e = (
    a[...].shape,        # nothing indexed: shape unchanged
    a[0, ...].shape,     # first dim indexed away
    a[:, 1, ...].shape,  # second dim indexed away
    a[..., :2].shape,    # only the last dim sliced
)
for _label, _shape in (("b:\t", b), ("c:\t", c), ("d:\t", d), ("e:\t", e)):
    print(_label, _shape)
b: torch.Size([4, 3, 28, 28])
c: torch.Size([3, 28, 28])
d: torch.Size([4, 28, 28])
e: torch.Size([4, 3, 28, 2])
select by mask:
import torch

x = torch.randn(3, 4)
print("x:\t", x)
# Boolean tensor with the same shape as x: True where x >= 0.5.
mask = x.ge(0.5)
print("mask:\t", mask)
# masked_select returns a new 1-D tensor holding the elements where mask is
# True; its length depends on the data, not on x's shape.
y = torch.masked_select(x, mask)
print("y:\t", y)
# Reuse y rather than running masked_select a second time.
z = y.shape
print("z:\t", z)
x: tensor([[ 0.2149, -1.4181, 0.0112, 2.2036],
[-0.6523, 0.1513, 0.1381, 0.0905],
[-0.7174, -1.4634, -0.3409, -1.2119]])
mask: tensor([[False, False, False, True],
[False, False, False, False],
[False, False, False, False]])
y: tensor([2.2036])
z: torch.Size([1])
select by flatten index:
import torch

src = torch.tensor([[4, 3, 5], [6, 7, 8]])
print("src:\t", src)
# take() indexes src as though it were flattened to 1-D (row by row), so
# flat positions 0, 2, 5 pick out 4, 5 and 8.
_flat_index = torch.tensor([0, 2, 5])
a = src.take(_flat_index)
print("a:\t", a)
src: tensor([[4, 3, 5],
[6, 7, 8]])
a: tensor([4, 5, 8])