pytorch -- torch.linspace().repeat()用法详解

今天在看代码看到有用两组 torch.linspace().repeat() 来生成网格,不太理解。后来查了一查大概明白这个函数了,现在记下来加深点印象。

torch.linspace()

在pycharm中ctrl摁住再点击linspace()就可以查看到inspace()这个函数的全貌

def linspace(start: Number, end: Number, steps: _int=100, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...

基本上只会用前面三个参数。

start:起始数字

end:末尾数字

steps:起始和末尾之间的点的个数

下面放个例子会清楚点 ->

t = torch.linspace(0,4,5)
print(t)

结果为

tensor([0., 1., 2., 3., 4.])

简单来说就是生成一个由等差数列构成的张量。

repeat()

repeat()函数就是重复的意思,里面可以有许多参数。比如repeat(x,y,z)具体来说就是将原本的张量的行数变为原本的y倍,其列数变为原本的z倍,然后再其深度上变为原本的x倍。当然里面的参数可以远不止3个。简单来说就是在其维度上变为原来的几倍。再放一个例子->

a = torch.linspace(0,4,5).repeat(2,1)
print(a)
print(a.shape)

结果为

tensor([[0., 1., 2., 3., 4.],
        [0., 1., 2., 3., 4.]])
torch.Size([2, 5])

 --------------------------------------

b = torch.linspace(0,4,5).repeat(3,2,1)
print(b)
print(b.shape)

结果为

tensor([[[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]],

        [[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]],

        [[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]]])
torch.Size([3, 2, 5])

repeat()函数后面还可以跟上一些函数。比如t(),可以将张量转置,不过这个张量不能超过二维。三维的就不行了。

你可能感兴趣的:(学习笔记,python,pytorch)