PyTorch 中的常用张量操作

作者:Samyukta Nair

编译:ronghuaiyang

原文地址:https://mp.weixin.qq.com/s/YM_LoLtjkt2pUta0ikr3UQ

目录:

序号 函数
01 expand()
02 permute()
03 tolist()
04 narrow()
05 where

expand() 函数

  • 作用:将 当前张量 沿着 任意一维 或者 多维 进行 展开 得到新的张量,默认情况下是在 1 维度进行展开,如果不想沿着一个特定的维度进行展开,可以设置参数为 -1
  • 代码实现
import torch 
a = torch.tensor([[[1, 2, 3], [4, 5, 6]]])
print(a.size())

torch.Size([1, 2, 3])

a.expand(2, 2, 3)
print(a)

tensor([[[1, 2, 3],
[4, 5, 6]],

[[1, 2, 3],
[4, 5, 6]]])

  • 说明:在上述案例中,张量的原始维度为 [1, 2, 3],随后扩展成了 [2, 2, 3]

permute() 函数

  • 作用:返回一个张量的视图,原始张量的维度根据我们的选择而改变。我们可以将维度为 [2, 3, 4] 的张量改为 [4, 3, 2] 的张量,函数的参数是维数的顺序
  • 代码实现
a = torch.tensor([[[1, 2, 3], [4, 5, 6]]])
print(a.size())

torch.Size([1, 2, 3])

# 将第二维度和第零维度相换
a.permute(2, 1, 0).size()
print(a.size())

torch.Size([3, 2, 1])

a.permute(2, 1, 0)
print(a)

tensor([[[1],
[4]],

[[2],
[5]],

[[3],
[6]]])

  • 说明:原始张量的维度是 [1, 2, 3],在使用 permute() 函数之后,将输出顺序改成 [2, 1, 0],即我们得到的新的张量的维度为 [3, 2, 1]。如果我们想对不同维度的张量进行重新排序,或者使用不同阶数的矩阵来进行矩阵乘法,可以使用 permute() 函数

tolist() 函数

  • 作用:返回的张量形式为 Python数字列表 或者 嵌套列表,而后可以对返回的类型进行 Python 逻辑操作
  • 代码实现:
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
a.tolist()
print(a)

[[1, 2, 3], [4, 5, 6]]

  • 代码说明:张量最终以列表的形式输出 【张量的表示形式类列表】

narrow() 函数

  • 作用:得到一个新的张量,这个张量是原来张量的缩小版本。narrow(输入张量, 要缩小的维度, 起始索引, 新张量沿该维数的长度),返回从 索引 start 到索引 (start + length - 1) 中的元素。其类似于高级索引
  • 代码实现
a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [14, 15, 16, 17]])  # torch.Size([4, 4])
print(torch.narrow(a, 1, 2, 2))

tensor([[ 3, 4],
[ 7, 8],
[11, 12],
[16, 17]])

  • 说明:沿着第二维度,从索引2开始,到索引3(3 = 2 + 2 - 1, 即 start + length - 1)

where() 函数

  • 作用:得到一个新的张量,其值在每一个索引出根据给定条件而改变。where(条件, 第一个张量, 第二个张量)。在每个张量对应位置的值进行条件对比,如果为真,则用第一个张量中相同位置的值代替,如果为假,就用第二个张量中相同位置的值代替

  • 代码实现

a = torch.tensor([[[1,2,3], [4,5,6]]]).to(torch.float32)
b = torch.zeros(1, 2, 3)
c = torch.where(a%2==0, b, a)

tensor([[[1., 0., 3.],
[0., 5., 0.]]])

  • 说明:以上代码的条件是检查张量 a 中的值是否是偶数,如果是,就用 张量b中的值代替相应位置,如果不是就用 张量a 的值代替相应位置。此函数可以设置阈值,如果张量的值大于或者小于某一数值,可以更容易被替换

你可能感兴趣的:(PyTorch 中的常用张量操作)