Pytorch 索引与切片

Pytorch 索引与切片

Indexing
import torch

# A 4-D tensor: 4 images, 3 channels, 28x28 pixels.
a = torch.rand(4, 3, 28, 28)

# Indexing the first (left-most) dimension selects one image.
print("a[0].shape:\t", a[0].shape)
print("a[1].shape:\t", a[1].shape)

# Indexing two dimensions selects one channel of one image.
print('a[0, 0].shape:\t', a[0, 0].shape)
# BUG FIX: the original printed a[0, 0].shape under the a[1, 1] label.
print("a[1, 1].shape:\t", a[1, 1].shape)

# Indexing every dimension yields a 0-dim scalar tensor.
c = a[0, 0, 2, 4]
print("c:\t", c)
a[0].shape:	 torch.Size([3, 28, 28])
a[1].shape:	 torch.Size([3, 28, 28])
a[0, 0].shape:	 torch.Size([28, 28])
a[1, 1].shape:	 torch.Size([28, 28])
c:	 tensor(0.6412)
select first / last N
import torch

# Demonstrate slice indexing: taking the first / last N entries of a dim.
a = torch.rand(4, 3, 28, 28)
print("a.shape:\t", a.shape)

# First two images along the batch dimension.
first_two = a[:2]
print("a[:2].shape\t", first_two.shape)

# First two images, first channel only.
print("a[:2,:1,:,:].shape:\t", first_two[:, :1, :, :].shape)

# First two images, channels from index 1 to the end.
print("a[:2,1:,:,:].shape:\t", first_two[:, 1:, :, :].shape)

# First two images, last channel only (negative start index).
print("a[:2,-1:,:,:].shape:\t", first_two[:, -1:, :, :].shape)
a.shape:	 torch.Size([4, 3, 28, 28])
a[:2].shape	 torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape:	 torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape:	 torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape:	 torch.Size([2, 1, 28, 28])
select by steps: start:end:step
import torch

a = torch.rand(4, 3, 28, 28)
# A step may follow start:end, i.e. start:end:step; 0:28 is shorthand for 0:28:1.
# BUG FIX: the label now matches the expression (both spatial dims use step 2).
print("a[:,:,0:28:2,0:28:2].shape:\t", a[:,:,0:28:2,0:28:2].shape)
# ::2 is the same stepping with start/end left implicit.
print("a[:,:,::2,::2].shape:\t", a[:,:,::2,::2].shape)
a[:,:,0:28,0:28:2].shape:	 torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape:	 torch.Size([4, 3, 14, 14])
select by specific index:
a = torch.rand(4, 3, 28, 28)

# index_select(dim, indices) keeps only the listed indices along `dim`.
batch_idx = torch.tensor([0, 2])
chan_idx = torch.tensor([1, 2])

# Images 0 and 2 -> batch dimension shrinks to 2.
b = a.index_select(0, batch_idx).shape
print("b:\t", b)

# Channels 1 and 2 -> channel dimension shrinks to 2.
c = a.index_select(1, chan_idx).shape
print("c:\t", c)

# All 28 row indices -> shape unchanged.
d = a.index_select(2, torch.arange(28)).shape
print("d:\t", d)

# Only the first 8 row indices -> height becomes 8.
e = a.index_select(2, torch.arange(8)).shape
print("e:\t", e)
b:	 torch.Size([2, 3, 28, 28])
c:	 torch.Size([4, 2, 28, 28])
d:	 torch.Size([4, 3, 28, 28])
e:	 torch.Size([4, 3, 8, 28])
“…”:代表任意多的维度
import torch

# "..." (Ellipsis) stands in for as many full slices (:) as needed.
a = torch.rand(4, 3, 28, 28)

# a[...] selects everything; the shape is unchanged.
b = a[...].shape
print("b:\t", b)

# Equivalent to a[0, :, :, :].
c = a[0, ...].shape
print("c:\t", c)

# Equivalent to a[:, 1, :, :].
d = a[:, 1, ...].shape
print("d:\t", d)

# Ellipsis may also lead: restrict only the last dim to its first 2 entries.
e = a[..., :2].shape
print("e:\t", e)
b:	 torch.Size([4, 3, 28, 28])
c:	 torch.Size([3, 28, 28])
d:	 torch.Size([4, 28, 28])
e:	 torch.Size([4, 3, 28, 2])
select by mask:
import torch

x = torch.randn(3, 4)
print("x:\t", x)

# ge(0.5) builds a boolean mask: True where the element is >= 0.5.
mask = x.ge(0.5)
print("mask:\t", mask)

# masked_select() returns a flat 1-D tensor of the elements where mask is True.
y = torch.masked_select(x, mask)
print("y:\t", y)

# FIX: reuse y instead of running masked_select a second time.
z = y.shape
print("z:\t", z)
x:	 tensor([[ 0.2149, -1.4181,  0.0112,  2.2036],
        [-0.6523,  0.1513,  0.1381,  0.0905],
        [-0.7174, -1.4634, -0.3409, -1.2119]])
mask:	 tensor([[False, False, False,  True],
        [False, False, False, False],
        [False, False, False, False]])
y:	 tensor([2.2036])
z:	 torch.Size([1])
select by flatten index:
import torch

# torch.take() indexes the tensor as if it had been flattened to 1-D.
src = torch.tensor([[4, 3, 5], [6, 7, 8]])
print("src:\t", src)

# Flat positions 0, 2 and 5 -> elements 4, 5 and 8.
a = torch.take(src, torch.tensor([0, 2, 5]))
print("a:\t", a)
src:	 tensor([[4, 3, 5],
        [6, 7, 8]])
a:	 tensor([4, 5, 8])

你可能感兴趣的:(Pytorch,深度学习)