一. torch.repeat_interleave()函数解析
1.函数说明
官网:torch.repeat_interleave(),函数说明如下图所示:
2. 函数原型
torch.repeat_interleave(input, repeats, dim=None) → Tensor
3. 函数功能
沿着指定的维度重复张量的元素
4. 输入参数:
1)input (类型:torch.Tensor):输入张量
2)repeats(类型:int或torch.Tensor):每个元素的重复次数
3)dim(类型:int)需要重复的维度。默认情况下dim=None,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复repeats次,并返回重复后的张量。
5. 注意
1)如果不指定dim,则默认将输入张量扁平化(维数是1,因此这时repeats必须是一个数,不能是数组),并且返回一个扁平化的输出数组。
2)返回的数组与输入数组维数相同,并且除了给定的维度dim,其他维度大小与输入数组相应维度大小相同
3)repeats:如果传入数组,则必须是tensor格式。并且只能是一维数组,数组长度与输入数组input的dim维度大小相同
6. 代码例子
6.1 输入一维张量,不指定dim,重复次数为2次,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复2次,并返回重复后的张量。
a = torch.randn(5)
a,torch.repeat_interleave(a,2)
输出结果如下所示:
(tensor([ 0.4030, -1.1536, -2.4513, 1.1454, -0.8818]),
tensor([ 0.4030, 0.4030, -1.1536, -1.1536, -2.4513, -2.4513, 1.1454, 1.1454,
-0.8818, -0.8818]))
6.2 输入二维张量,不指定dim,重复次数为2次,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复2次,并返回重复后的张量。
a = torch.randn(3,2)
a,a.repeat_interleave(2)
输出结果如下:
(tensor([[-1.03, -0.32],
[ 0.43, 0.78],
[ 0.91, -0.11]]),
tensor([-1.03, -1.03, -0.32, -0.32, 0.43, 0.43, 0.78, 0.78, 0.91, 0.91,
-0.11, -0.11]))
6.3 输入二维张量,指定dim=0,重复次数为3次,表示把输入张量每行元素重复3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=0)
输出结果如下:
(tensor([[ 0.14, 1.47],
[-1.52, -0.62],
[-0.24, -0.27]]),
tensor([[ 0.14, 1.47],
[ 0.14, 1.47],
[ 0.14, 1.47],
[-1.52, -0.62],
[-1.52, -0.62],
[-1.52, -0.62],
[-0.24, -0.27],
[-0.24, -0.27],
[-0.24, -0.27]]))
6.4 输入二维张量,指定dim=1,重复次数为3次,表示把输入张量每列元素重复3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=1)
输出结果如下:
(tensor([[-0.81, 0.56],
[-2.41, -0.56],
[ 0.38, -0.90]]),
tensor([[-0.81, -0.81, -0.81, 0.56, 0.56, 0.56],
[-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
[ 0.38, 0.38, 0.38, -0.90, -0.90, -0.90]]))
6.5 输入二维张量,指定dim=0,重复次数为一个张量列表[n1,n2,n3],表示在(dim=0)对应行上面重复n1,n2,n3遍,张量列表的长度必须与dim=0的维度的长度一样,否则会报错
a = torch.randn(3,2)
a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)#表示第一行重复2遍,第二行重复3遍,第三行重复4遍
输出结果如下:
(tensor([[-0.79, 0.54],
[-0.47, -0.25],
[-0.13, 1.03]]),
tensor([[-0.79, 0.54],
[-0.79, 0.54],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03]]))
7. 与torch.repeat()函数区别:
两个函数方法最大的区别就是repeat_interleave是一个元素一个元素地重复,而repeat是一组元素一组元素地重复.
参考知识文章