pytorch : expand 和 repeat 函数

expand 函数

expand(*sizes) -> Tensor
*sizes(torch.Size or int) - the desired expanded size
Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

expand用于扩展tensor数据。但有以下注意点:

  1. 该函数不复制数据
  2. 扩展时只在能度数是1的维度上扩展
  3. 生成的对象与原对象共享内存
import torch

a = torch.tensor([1,2])
b = a.expand(2,-1) # -1 代表此维度不变
print('b : ', b)
print('\nafter modifying b')
b[0][0]=10
print('a : ', a)
print('b : ', b)
    b :  tensor([[1, 2],
            [1, 2]])

    after modifying b
    a :  tensor([10,  2])
    b :  tensor([[10,  2],
            [10,  2]])
c = a.expand(2,4) # tesor 'a'最后一维维度是2,所以扩展时出错
RuntimeError: The expanded size of the tensor (4) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [2, 4].  Tensor sizes: [2]

repeat 函数

repeat(*sizes) -> Tensor
*size(torch.Size or int) - The number of times to repeat this tensor along each dimension.
Repeats this tensor along the specified dimensions.

返回tensor在某个维度上扩展后的张量.注意:

  1. 此函数会生成新的数据变量,和原tensor不共享内存
d = a.repeat(2,2)
print('d: ', d)
d[0][0] = 10
print('\nafter modifying d ')
print('a: ', a)
print('d: ', d)

a.repeat(2,4) # 此参数 expand() 函数不通过
d:  tensor([[1, 2, 1, 2],
        [1, 2, 1, 2]])
        
after modifying d 
a:  tensor([1, 2])
d:  tensor([[10,  2,  1,  2],
        [ 1,  2,  1,  2]])

tensor([[1, 2, 1, 2, 1, 2, 1, 2],
        [1, 2, 1, 2, 1, 2, 1, 2]])

你可能感兴趣的:(pytorch,pytorch)