Pytorch基础 - 4. torch.expand() 和 torch.repeat()

目录

1.  torch.expand(*sizes)

2. torch.repeat(*sizes)

3. 两者内存占用的区别


在PyTorch中有两个函数可以用来扩展某一维度的张量,即 torch.expand() 和 torch.repeat()

1.  torch.expand(*sizes)

【含义】将输入张量在大小为1的维度上进行拓展,并返回扩展更大后的张量

【参数】sizes的shape为torch.Size 或 int,指拓展后的维度, 当值为-1的时候,表示维度不变

import torch

if __name__ == '__main__':
    x = torch.rand(1, 3)
    y1 = x.expand(4, 3)
    print(y1.shape)  # torch.Size([4, 3])
    y2 = x.expand(6, -1)
    print(y2.shape)  # torch.Size([6, 3])

2. torch.repeat(*sizes)

【含义】沿着特定维度扩展张量,并返回扩展后的张量

【参数】sizes的shape为torch.Size 或 int,指对当前维度扩展的倍数

import torch

if __name__ == '__main__':
    x = torch.rand(2, 3)
    y1 = x.repeat(4, 2)
    print(y1.shape)  # torch.Size([8, 6])

3. 两者内存占用的区别

torch.expand 不会占用额外空间,只是在存在的张量上创建一个新的视图

torch.repeat 和 torch.expand 不同,它是拷贝了数据,会占用额外的空间

示例如下:

import torch

if __name__ == '__main__':
    x = torch.rand(1, 3)
    y1 = x.expand(4, 3)
    y2 = x.repeat(2, 3)
    print(x.data_ptr(), y1.data_ptr())  # 52364352 52364352
    print(x.data_ptr(), y2.data_ptr())  # 52364352 8852096

你可能感兴趣的:(#,Pytorch操作,pytorch,深度学习,人工智能,expand,repeat)