Pytorch基础(二)Tensor的索引和切片

Pytorch基础(二)Tensor的索引和切片

Tensor的index和select

  • **Dim 0 first:**多维张量的索引默认为第一个维度索引
a = torch.Tensor(4, 3, 28, 28)
print(a[0].shape) # torch.Size([3,28,28])
print(a[0,0].shape) # troch.Size([28,28])
  • 选择前N个或后N个
  1. :代表全部
  2. n:代表从第n个到最后(包括第n个)
  3. :n代表从第一个到第n个(不包括第n个)
  4. n:m代表从第n个到第m个(包括第n个,不包括第m个)
  5. n:m:x代表从第n个到第m个,每隔x个取一个(包括第n个,不包括第m个)
  6. 通用形式为:start:end:step​(不包括end)
a = torch.Tensor(4, 3, 28, 28)
print(a[:2].shape) # torch.Size([2,3,28,28]) 这里:2代表从0到2(不包括2)
print(a[2:].shape) # torch.Size([2,3,28,28]) 这里2:代表从2到最后(包括2)
print(a[-2:].shape) # torch.Size([2,3,28,28]) 这里-2:代表从倒数第二个到最后(包括倒数第二个)
print(a[:].shape) # torch.Size([4,3,28,28]) 这里:代表这维度的所有元素
print(a[:,:,0:14]) # torch.Size([4,3,14,28]) 这里0:14代表从0到14(不包括14)
print(a[:,:,0:28:2])# torch.Size([4,3,14,28]) 这里0:28:2代表从0到28(不包括28),每两个取一次
  • 选择特定的维度.index_select(dim, index)
a = torch.Tensor(4, 3, 28, 28)
print(a.index_select(0,torch.tensor([2, 3])).shape) # 沿着第0个维度进行切片,取第2和第3个tensor。
print(a.index_select(2,torch.arange(14)).shape) # 沿着第2个维度切片,取前14个tensor

注意,index这个参数必须是torch.tensor不能使用python中的list

  • 用省略号...代表任意维度

这里的...代表维度需要根据具体情况进行推测

所以这里的...必须是可以推测出的维度,比如最左/右或中间的维度

a = torch.Tensor(4, 3, 28, 28)
a[...].shape # torch.Size([4,3,28,28]) 这里代表所有维度
a[:,1,...].shape # torch.Size([4,28,28]) 这里最右边的所有维度
  • 通过mask(掩码)来进行筛选torch.masked_select()

注意,使用torch.masked_select()会将数据的维度打平,返回的tensor维度为1,长度不定

x = torch.randn(3,4)
print(x)
# out:
# tensor([[ 0.6797, -0.1078,  0.7623,  0.2214],
#         [-1.2354,  0.6120,  2.3871, -1.1993],
#         [-0.2460, -1.2034,  0.7166,  0.2186]])
mask = x.ge(0.5)
print(mask)
# out:
# tensor([[ True, False,  True, False],
#         [False,  True,  True, False],
#         [False, False,  True, False]])
a = torch.masked_select(x,mask)
print(a)
# tensor([0.6797, 0.7623, 0.6120, 2.3871, 0.7166])
print(a.shape)
# torch.Size([5])
  • 通过将tensor的维度打平来进行selecttorch.take()
a = torch.tensor([[4,3,5]
                 [6,7,8]])
torch.take(a,torch.tensor([0,2])) # tensor([4,5,8])

Tensor的维度变换

常用的API:

  • 形状变换:View/reshape

在pytorch0.3之后,view和reshape这两个函数功能完全相同,但要确保前后的numel一致

适合全连接层将维度打平时使用

但在维度改变时,会丢失原先tensor各个维度的意义。例如[b,c,h,w]打平后会破坏各个维度的顺序,还原后顺序会被改变

  • 挤压和扩充:Squeeze/unsqueeze

squeeze只能将shape为1的维度压缩

unsqueeze不会改变具体的数据,只是将tensor的维度进行扩充,被扩充的维度的shape还是1

  • 转置:Transpose/t/permute

transpose只能将其中某两个维度进行交换

要多个维度交换可以使用多次transpose,也可以直接使用一次permute

不论是transpose还是permute,维度交换后内存地址还是不连续的,一般可以后面添加.contiguous()使内存地址连续,这样再进行其他操作减小了报错的可能性

  • 维度扩展:Expand/repeat

你可能感兴趣的:(机器学习)