numpy和torch的repeat函数有大区别,小心混淆

首先是numpy的repeat函数

import numpy as np

# numpy的repeat有两个参数
# 参数1:代表重复次数
# 参数2:代表重复维度
x = np.array([[1,2]])
print(x.shape)  # (1, 2)
print(x.repeat(2,0).shape) #(2, 2)
print(x.repeat(3,1).shape) #(1, 6)

# 结果为:
# (1, 2)
# (2, 2)
# (1, 6)

然后是torch的repeat函数

import torch

# numpy的repeat参数量与本身维度有关
# 参数1:代表第一维重复次数
# 参数2:代表第二维重复次数
# 参数3:代表第三维重复次数
# 另外,repeat用四个参数操作一个三维张量时,会自动补维度,默认补到第一维
x = torch.tensor([[[1,2]],[[1,2]]])
print(x.shape)
print(x.repeat(1,1,1).shape)
print(x.repeat(3,2,1).shape)
print(x.repeat(3,2,1,1).shape)

# 结果为:
# torch.Size([2, 1, 2])
# torch.Size([2, 1, 2])
# torch.Size([6, 2, 2])
# torch.Size([3, 4, 1, 2])

你可能感兴趣的:(深度学习,pytorch,python,深度学习)