expand() does not allocate new memory for the tensor; it only creates a new view over the existing tensor.
import torch
x = torch.tensor([1, 2, 3])
x = x.expand(2, 3)
print(x.shape)
'''
torch.Size([2, 3])
'''
expand() behaves like the broadcasting mechanism: we can think of expand as multiplying by torch.ones(). In the code above, x originally has shape [3], so x.expand(2, 3) can be read as x * torch.ones(2, 3). From the perspective of shapes, that is:

[3] * [2, 3] ↪ [1, 3] * [2, 3] ⇝ [2, 3] * [2, 3]
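Following the broadcasting analogy above, expand() can only enlarge dimensions of size 1 (or add new leading dimensions); a minimal sketch of both the success and failure cases:

```python
import torch

x = torch.tensor([1, 2, 3])   # shape [3]
y = x.expand(2, -1)           # -1 means "keep the existing size" (here 3)
print(y.shape)                # torch.Size([2, 3])

# Only size-1 dims can be expanded; trying to stretch the existing
# size-3 dim to 4 raises a RuntimeError.
try:
    x.expand(4)
except RuntimeError as e:
    print("error:", e)
```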
repeat() tiles the tensor along the specified dimensions and, unlike expand(), this function copies the tensor's data.
import torch
x = torch.tensor([1, 2, 3])
x = x.repeat(3, 2)
print(x)
'''
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
'''
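The output shape of repeat() is the input shape multiplied dim-by-dim by the given repeat counts (with extra leading 1s prepended to the input shape if more counts are given than dims, as in the [3] → [3, 6] example above). A small 2-D sketch:

```python
import torch

a = torch.arange(6).reshape(2, 3)   # shape [2, 3]
b = a.repeat(2, 1)                  # dim 0 tiled twice -> shape [4, 3]
c = a.repeat(1, 2)                  # dim 1 tiled twice -> shape [2, 6]
print(b.shape, c.shape)
```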
import torch
a = torch.ones(3, 1).float()
b = a.expand(3, 4)   # view over a's memory
c = a.repeat(1, 4)   # fresh copy of a's data
'''
data_ptr(): Returns the address of the first element of self tensor.
'''
print(a.data_ptr(), b.data_ptr(), c.data_ptr())
a[0, 0] = 3
print(f"b: {b}")
print(f"c: {c}")
'''
2151521637440 2151521637440 2151543296960
tensor([[3., 3., 3., 3.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
'''
expand() returns a new view over the original memory, so after changing a's value the change is visible through b; repeat() allocates new memory, so c is unaffected.
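Besides comparing data_ptr() as above, the view/copy difference also shows up in the strides: expand() produces a stride of 0 along the expanded dimension (the same element is revisited instead of stepping through memory), while repeat() yields an ordinary contiguous layout. A quick check:

```python
import torch

a = torch.ones(3, 1)
b = a.expand(3, 4)   # view: stride 0 along the expanded dim
c = a.repeat(1, 4)   # copy: genuine contiguous [3, 4] storage

print(b.stride())    # (1, 0) -- no memory step along dim 1
print(c.stride())    # (4, 1)
print(b.is_contiguous(), c.is_contiguous())
```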