torch.repeat_interleave()
dim
描述在线性代数课程中,接触的最多的是以行和列对矩阵进行描述,但是在拥有更多维度的张量中则有所不同
我们使用 torch 生成一个张量
y = torch.tensor([[1, 2], [3, 4]])
利用数学形式的表示,我们可以写作
y = [ 1 3 2 4 ] y = \begin{bmatrix}1 & 3 \\ 2 & 4\end{bmatrix} y=[1234]
或者以 torch 的形式描述为
y = [ [ 1 2 ] [ 3 4 ] ] y = \begin{bmatrix} \begin{bmatrix} 1 \\2\end{bmatrix}\begin{bmatrix}3 \\4\end{bmatrix}\end{bmatrix} y=[[12][34]]
第一层括号记作 dim=0
,遍历 dim=0
的层次,得到两个张量
y 1 = [ 1 2 ] y 2 = [ 3 4 ] y_1 = \begin{bmatrix} 1 \\ 2\end{bmatrix} \\ y_2 =\begin{bmatrix} 3 \\ 4\end{bmatrix} y1=[12]y2=[34]
第二重括号记作 dim=1
,遍历 dim=1
的层次,得到四个数字
y 11 = 1 y 12 = 2 y 21 = 3 y 22 = 4 y_{11} = 1 \\y_{12} =2 \\y_{21} = 3 \\ y_{22} = 4 y11=1y12=2y21=3y22=4
torch.repeat_interleave()
首先复制一下 PyTorch
官方文档的描述
torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor
参数有三个
input
:待处理的 tensorrepeat
:每个元素重复几次dim
:在那个维度上重复看几个例子,以上一节中的 y y y 为例
dim=0
torch.repeat_interleave(y, 2, dim=0)
在 dim=0
,在 dim=0
,即所有张量上进行操作,每个一维 tensor 复制 1 次
tensor([[1, 2],
[1, 2],
[3, 4],
[3, 4]])
dim=1
torch.repeat_interleave(y, 3, dim=1)
在 dim=1
,即每个一维张量内部元素上进行操作,所有元素复制 2 次
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
不提供 dim
参数
y.repeat_interleave(2)
将默认返回 y.Flatten()
tensor([1, 1, 2, 2, 3, 3, 4, 4])
repeat
参数本身也能是一个 tensor
torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
在 dim=0
,即所有张量上进行操作,第一个 tensor 复制 0 次,第二个复制 1 次
tensor([[1, 2],
[3, 4],
[3, 4]])
最后贴上官方文档的链接