二者都是用来扩展某维的数据的尺寸
一、expand()
返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),只能扩展为1的维度:
a = torch.tensor([1,2,3,4])
print('扩展前a的shape:', a.shape)
a = a.expand(8, 4)
print('扩展后a的shape:', a.shape)
print(a)
输出:
扩展前a的shape: torch.Size([4])
扩展后a的shape: torch.Size([8, 4])
tensor([[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]])
二、repeat()
沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据:
a = torch.tensor([1,2,3,4])
print('repeat前a的shape:', a.shape)
a = a.repeat(3, 2)
print('repeat后a的shape:', a.shape)
print(a)
输出:
repeat前a的shape: torch.Size([4])
repeat后a的shape: torch.Size([3, 8])
tensor([[1, 2, 3, 4, 1, 2, 3, 4],
[1, 2, 3, 4, 1, 2, 3, 4],
[1, 2, 3, 4, 1, 2, 3, 4]])