最近在学习沐神的d2l的时候,深受其中代码的折磨,有些函数真的是从来没见过,组合起来更是让人头皮发麻,根本看不懂代码在写些什么。
写这篇文章,主要是为了总结一下Python当中的repeat()函数和repeat_interleave()函数,这两个函数在应用于Pytorch和Numpy数组的时候得到的结果也是不一样的,所以有很大的槽点需要注意!
首先是总结应用于Pytorch领域的repeat()函数和repeat_interleave()函数:
1.repeat()
话不多说,直接上代码:
import torch
original_tensor = torch.tensor([[1, 2], [3, 4]])
repeated_tensor = original_tensor.repeat(2, 3)
print(repeated_tensor)
输出为:
tensor([[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4],
[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4]])
不难从输出当中得出结论:.repeat(2, 3)就是沿着第一个维度(行)重复 2 次,沿着第二个维度(列)重复 3 次,最终生成了一个 4x6 的张量。注意repeat是一组元素一组元素地重复,这与下面的repeat_interleave()函数是不相同的。
2.repeat_interleave()
该函数与repeat()函数的区别在于,它是沿着指定的维度复制张量元素
①不指定dim,重复次数为2次,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复2次,并返回重复后的张量。
a = torch.randn(3,2)
a,a.repeat_interleave(2)
输出为:
(tensor([[-1.03, -0.32],
[ 0.43, 0.78],
[ 0.91, -0.11]]),
tensor([-1.03, -1.03, -0.32, -0.32, 0.43, 0.43, 0.78, 0.78, 0.91, 0.91,
-0.11, -0.11]))
②输入二维张量,指定dim=0,重复次数为3次,表示把输入张量每行元素重复3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=0)
输出为:
(tensor([[ 0.14, 1.47],
[-1.52, -0.62],
[-0.24, -0.27]]),
tensor([[ 0.14, 1.47],
[ 0.14, 1.47],
[ 0.14, 1.47],
[-1.52, -0.62],
[-1.52, -0.62],
[-1.52, -0.62],
[-0.24, -0.27],
[-0.24, -0.27],
[-0.24, -0.27]]))
③输入二维张量,指定dim=1,重复次数为3次,表示把输入张量每列元素重复3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=1)
输出为:
(tensor([[-0.81, 0.56],
[-2.41, -0.56],
[ 0.38, -0.90]]),
tensor([[-0.81, -0.81, -0.81, 0.56, 0.56, 0.56],
[-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
[ 0.38, 0.38, 0.38, -0.90, -0.90, -0.90]]))
④输入二维张量,指定dim=0,重复次数为一个张量列表[n1,n2,n3],表示在(dim=0)对应行上面重复n1,n2,n3遍,张量列表的长度必须与dim=0的维度的长度一样,否则会报错
a = torch.randn(3,2)
a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)
输出为:
(tensor([[-0.79, 0.54],
[-0.47, -0.25],
[-0.13, 1.03]]),
tensor([[-0.79, 0.54],
[-0.79, 0.54],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03]]))
总结:可以看出,两个函数方法最大的区别就是repeat_interleave是一个元素一个元素地重复,而repeat是一组元素一组元素地重复
那到这里就完了吗?完全没有!经过测试发现,以上都是repeat()函数和repeat_interleave()函数应用于pytorch的tensor张量,但当它们应用于numpy数组时,结果又是不一样的!
例如:
test_array = torch.arange(9).reshape(3, 3)
print('采用torch tensor原始:\n', test_array)
print('采用torch tensor的repeat函数:\n', test_array.repeat(2, 1))
print('采用torch tensor的repeat_interleave函数:\n', test_array.repeat_interleave(2, dim=0))
test_array2 = np.arange(9).reshape(3, 3)
print('采用numpy array原始:\n', test_array2)
print('采用numpy array的repeat函数:\n', test_array2.repeat(2, 1))
print('采用numpy array的repeat_interleave函数:\n', test_array2.repeat_interleave(2, dim=0))
我们运行上述代码,看看结果怎么样:
采用torch tensor原始:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
采用torch tensor的repeat函数:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
采用torch tensor的repeat_interleave函数:
tensor([[0, 1, 2],
[0, 1, 2],
[3, 4, 5],
[3, 4, 5],
[6, 7, 8],
[6, 7, 8]])
采用numpy array原始:
[[0 1 2]
[3 4 5]
[6 7 8]]
采用numpy array的repeat函数:
[[0 0 1 1 2 2]
[3 3 4 4 5 5]
[6 6 7 7 8 8]]
Traceback (most recent call last):
File "D:/PythonProject/DiveIntoDeepLearning(LiMu)/main.py", line 82, in <module>
print('采用numpy array的repeat_interleave函数:\n', test_array2.repeat_interleave(2, dim=0))
AttributeError: 'numpy.ndarray' object has no attribute 'repeat_interleave'
Process finished with exit code 1
从输出结果可以得出以下结论:
①pytorch当中的numpy.repeat(2, 1)是指在第一个维度(行)上复制两次,在第二个维度(列)上复制1次,并且是一组元素一组元素地复制;Numpy当中的.repeat(2, 1)是指在第二个维度上(列,对应dim值为1)复制两次,并且是一个一个元素的复制
②numpy没有repeat_interleave函数