pytorch torch.expand和torch.repeat的区别

expand
对矩阵中的单维度进行复制

import torch

x = torch.tensor([1, 2, 3, 4]) #x.shape=(1,4)   shape的大小和expand的大小一样

xnew = x.expand(2, 4)          #x.shape=(2,4)  shape的大小和expand的大小一样
#对第0维复制2次
#对于不需要扩张的维度填上原来的维度大小 比如上面例子的4
print(xnew)

结果

tensor([[1, 2, 3, 4],
        [1, 2, 3, 4]])

如果


x = torch.tensor([1, 2, 3, 4]) 
xnew = x.expand(2, 5)   
#将出现报错,说明只能对单维度的复制       

如果

import torch

x = torch.tensor([1, 2, 3, 4])
xnew = x.expand(3,2, 4)
print(xnew)

结果
tensor([[[1, 2, 3, 4],
         [1, 2, 3, 4]],

        [[1, 2, 3, 4],
         [1, 2, 3, 4]],

        [[1, 2, 3, 4],
         [1, 2, 3, 4]]])

  
  

repeat
对矩阵横向、纵向地复制

import torch

x = torch.tensor([1, 2, 3,4])
xnew = x.repeat(2,3)
print(xnew)

结果
tensor([[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
        [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]])


用repeat实现上面expand的结果

import torch

x = torch.tensor([1, 2, 3,4])
xnew = x.repeat(2,1)
print(xnew)

结果
tensor([[1, 2, 3, 4],
        [1, 2, 3, 4]])

你可能感兴趣的:(pytorch)