pytorch中expand()和repeat()的区别

二者都是用来扩展某维的数据的尺寸

一、expand()
       返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),只能扩展为1的维度:

a = torch.tensor([1,2,3,4])
print('扩展前a的shape:', a.shape)
a = a.expand(8, 4)
print('扩展后a的shape:', a.shape)
print(a)

输出:
扩展前a的shape: torch.Size([4])
扩展后a的shape: torch.Size([8, 4])
tensor([[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]])

二、repeat()
       沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据:

a = torch.tensor([1,2,3,4])
print('repeat前a的shape:', a.shape)
a = a.repeat(3, 2)
print('repeat后a的shape:', a.shape)
print(a)

输出:
repeat前a的shape: torch.Size([4])
repeat后a的shape: torch.Size([3, 8])
tensor([[1, 2, 3, 4, 1, 2, 3, 4],
        [1, 2, 3, 4, 1, 2, 3, 4],
        [1, 2, 3, 4, 1, 2, 3, 4]])

你可能感兴趣的:(pytorch,python,pytorch)