torch 中的 dim 参数,以及 torch.repeat_interleave() 函数

torch 中的 dim 参数,以及torch.repeat_interleave()

张量的维度:用 dim 描述

在线性代数课程中,接触的最多的是以行和列对矩阵进行描述,但是在拥有更多维度的张量中则有所不同

我们使用 torch 生成一个张量

y = torch.tensor([[1, 2], [3, 4]]) 

利用数学形式的表示,我们可以写作
y = [ 1 3 2 4 ] y = \begin{bmatrix}1 & 3 \\ 2 & 4\end{bmatrix} y=[1234]
或者以 torch 的形式描述为
y = [ [ 1 2 ] [ 3 4 ] ] y = \begin{bmatrix} \begin{bmatrix} 1 \\2\end{bmatrix}\begin{bmatrix}3 \\4\end{bmatrix}\end{bmatrix} y=[[12][34]]
第一层括号记作 dim=0,遍历 dim=0 的层次,得到两个张量
y 1 = [ 1 2 ] y 2 = [ 3 4 ] y_1 = \begin{bmatrix} 1 \\ 2\end{bmatrix} \\ y_2 =\begin{bmatrix} 3 \\ 4\end{bmatrix} y1=[12]y2=[34]
第二重括号记作 dim=1,遍历 dim=1 的层次,得到四个数字
y 11 = 1 y 12 = 2 y 21 = 3 y 22 = 4 y_{11} = 1 \\y_{12} =2 \\y_{21} = 3 \\ y_{22} = 4 y11=1y12=2y21=3y22=4

张量的复制:torch.repeat_interleave()

首先复制一下 PyTorch 官方文档的描述

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor

参数有三个

  • input:待处理的 tensor
  • repeat:每个元素重复几次
  • dim:在那个维度上重复

看几个例子,以上一节中的 y y y 为例

dim=0

 torch.repeat_interleave(y, 2, dim=0)

dim=0,在 dim=0,即所有张量上进行操作,每个一维 tensor 复制 1 次

tensor([[1, 2],
        [1, 2],
        [3, 4],
       	[3, 4]])

dim=1

torch.repeat_interleave(y, 3, dim=1)

dim=1 ,即每个一维张量内部元素上进行操作,所有元素复制 2 次

tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])

不提供 dim 参数

y.repeat_interleave(2)

将默认返回 y.Flatten()

tensor([1, 1, 2, 2, 3, 3, 4, 4])

repeat 参数本身也能是一个 tensor

torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)

dim=0,即所有张量上进行操作,第一个 tensor 复制 0 次,第二个复制 1 次

tensor([[1, 2],
        [3, 4],
        [3, 4]])

最后贴上官方文档的链接

你可能感兴趣的:(Numpy快速上手教程,pytorch,python,深度学习)