PyTorch: the difference between repeat() and expand()

Table of contents

    • The difference between repeat() and expand()
      • expand()
      • repeat()
      • The difference between repeat() and expand()
    • References

The difference between repeat() and expand()

expand()

expand() does not allocate new memory for the tensor; it only creates a new view over the existing one.

import torch

x = torch.tensor([1, 2, 3])
x = x.expand(2, 3)
print(x.shape)

'''
torch.Size([2, 3])
'''

expand() behaves like broadcasting. We can think of expand() in terms of torch.ones(): in the code above, x originally has shape [3], and x.expand(2, 3) can be read as x * torch.ones(2, 3). From the shape perspective:

$[3] * [2, 3] \hookrightarrow [1, 3] * [2, 3] \rightsquigarrow [2, 3] * [2, 3]$
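Because of this broadcasting behavior, expand() only works along dimensions of size 1 (or newly added leading dimensions); trying to expand a non-singleton dimension to a different size raises an error. A small sketch (not from the original post) illustrating both cases, where -1 means "keep this dimension's size":

```python
import torch

x = torch.tensor([1, 2, 3])  # shape [3]

# -1 keeps the existing size of that dimension
y = x.expand(2, -1)
print(y.shape)   # torch.Size([2, 3])

# Expanding a non-singleton dimension (size 3) to a different size fails:
try:
    x.expand(6)
except RuntimeError as e:
    print("RuntimeError:", e)
```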

repeat()

repeat() repeats the tensor along the specified dimensions, and it copies the tensor's data.

import torch

x = torch.tensor([1, 2, 3])
x = x.repeat(3, 2)
print(x)

'''
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])
'''
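Unlike expand(), the arguments to repeat() are repetition counts, one per output dimension: each output size is the original size multiplied by its count, and extra leading counts add new dimensions. A brief sketch (my own addition, not from the original post):

```python
import torch

x = torch.tensor([1, 2, 3])   # shape [3]

# One count per output dimension; sizes multiply.
print(x.repeat(2).shape)      # torch.Size([6])    -> [1, 2, 3, 1, 2, 3]
print(x.repeat(2, 1).shape)   # torch.Size([2, 3]) -> leading count adds a dim
print(x.repeat(2, 3).shape)   # torch.Size([2, 9]) -> last dim: 3 * 3 = 9
```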

The difference between repeat() and expand()

import torch

a = torch.ones(3, 1).float()
b = a.expand(3, 4)
c = a.repeat(1, 4)

'''
data_ptr(): Returns the address of the first element of self tensor.
'''
print(a.data_ptr(), b.data_ptr(), c.data_ptr())
a[0, 0] = 3
print(f"b: {b}")
print(f"c: {c}")

'''
2151521637440 2151521637440 2151543296960
tensor([[3., 3., 3., 3.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
'''

expand() returns a new view over the original memory, so after changing a's value the change is also visible through b; repeat() allocates new memory, so c is unaffected.
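The view/copy distinction also shows up in the strides: the expanded dimension of b has stride 0, so every column reuses the same underlying element, while c has its own contiguous storage. A small sketch (my own addition) checking this, and materializing an independent copy of a view with .contiguous() when one is needed:

```python
import torch

a = torch.ones(3, 1)
b = a.expand(3, 4)   # view: no data copied
c = a.repeat(1, 4)   # real copy: 12 contiguous elements

print(b.stride())    # (1, 0) -- stride 0 in dim 1 reuses the same element
print(c.stride())    # (4, 1) -- normal contiguous layout

# If a writable, independent copy of the expanded view is needed,
# materialize it with .contiguous() (or .clone()):
d = b.contiguous()
print(d.data_ptr() == b.data_ptr())   # False: d has its own memory
```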

References

1. PyTorch study notes — the difference between repeat() and expand()
