PyTorch常用函数 | squeeze与unsqueeze函数 | flatten函数 | Pytorch中的各种乘法 | mul与mv与mm与dot与matmul函数

文章目录

    • 一、squeeze与unsqueeze函数
      • 1.squeeze函数
      • 2.unsqueeze函数
    • 二、flatten函数
    • 三、Pytorch中的各种乘法
      • 1.mul函数
      • 2.mv函数——矩阵向量乘法
      • 3.mm函数——矩阵乘法
      • 4.dot函数——一维向量点积
      • 5.matmul函数

一、squeeze与unsqueeze函数

  顾名思义,squeeze函数的作用是压缩(降维);unsqueeze函数的作用是解压(升维)。

1.squeeze函数

torch.squeeze(input, dim=None, *, out=None)

功能

  • 返回一个所有维度不为1的张量。
    比如,输入张量维度为A×1×B×C×1×D,则经squeeze函数的输出张量维度为A×B×C×D
  • dim参数设置时,只在所设置的dim维度进行张量的压缩,如果所设置的dim维度的尺寸不为1,则张量不会发生任何改变。
    比如,输入张量的维度为A×1×Bsqueeze(input, 0)函数不会使得输出张量的维度有任何改变。但是squeeze(input, 1)函数会使得输出张量的维度变为A×B

需要注意的是:

  1. 返回的张量与输入张量共享内存,因此,其中一个张量的内容改变,另一个张量的内容也会改变
  2. 如果一个张量batch维度设置为1,那么经squeeze函数将会移除batch维度,这会导致错误

参数

  • input(Tensor):输入张量
  • dim(int, optional):如果这个参数被设置,则输入张量仅在这个维度(dim)上进行压缩

例子:

x = torch.zeros(2, 1, 2, 1, 2)
x.size()
=>torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x)
y.size()
=>torch.Size([2, 2, 2])
y = torch.squeeze(x, 0)
y.size()
=>torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x, 1)
y.size()
=>torch.Size([2, 2, 1, 2])

2.unsqueeze函数

torch.unsqueeze(input, dim)

功能

  • 返回一个在指定位置插入尺寸为 1 的新张量
  • dim参数的区间为[-input.dim() - 1, input.dim() + 1),负的dim维度相当于dim = dim + input.dim() + 1

参数

  • input (Tensor):输入张量
  • dim (int):插入的索引

案例

# shape=4
x = torch.tensor([1, 2, 3, 4])  
# shape=(1,4)
torch.unsqueeze(x, 0)
=>tensor([[ 1,  2,  3,  4]])
# shape=(4,1)
torch.unsqueeze(x, 1)
=>t

你可能感兴趣的:(#,PyTorch基础篇,pytorch,深度学习)