PyTorch 函数解释:expand、repeat

torch.Tensor有两个实例方法可以用来扩展某维的数据的尺寸,分别是 repeat()expand()

expand()

返回当前张量在某维扩展更大后的张量。按照指定size扩充。

扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小(size)等于1的维度扩展到更大的尺寸。
代码示例:

In [45]: x = torch.randn(1,3)

In [46]: x
Out[46]: tensor([[-1.1352,  0.3773, -0.2824]])

In [47]: x.expand(2, 3)
Out[47]:
tensor([[-1.1352,  0.3773, -0.2824],
        [-1.1352,  0.3773, -0.2824]])

In [48]: x.expand(2, -1)
Out[48]:
tensor([[-1.1352,  0.3773, -0.2824],
        [-1.1352,  0.3773, -0.2824]])

repeat()

沿着特定的维度重复这个张量,按照倍数扩充;和expand()不同的是,这个函数拷贝张量的数据。

In [53]: x
Out[53]: tensor([[-1.1352,  0.3773, -0.2824]])

In [54]: x.shape
Out[54]: torch.Size([1, 3])

In [55]: x.repeat(2,3)
Out[55]:
tensor([[-1.1352,  0.3773, -0.2824, -1.1352,  0.3773, -0.2824, -1.1352,  0.3773,
         -0.2824],
        [-1.1352,  0.3773, -0.2824, -1.1352,  0.3773, -0.2824, -1.1352,  0.3773,
         -0.2824]])

In [56]: x.repeat(2,3).shape
Out[56]: torch.Size([2, 9])

你可能感兴趣的:(深度学习,pytorch)