今天在看代码看到有用两组 torch.linspace().repeat() 来生成网格,不太理解。后来查了一查大概明白这个函数了,现在记下来加深点印象。
在pycharm中ctrl摁住再点击linspace()就可以查看到inspace()这个函数的全貌
def linspace(start: Number, end: Number, steps: _int=100, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
基本上只会用前面三个参数。
start:起始数字
end:末尾数字
steps:起始和末尾之间的点的个数
下面放个例子会清楚点 ->
t = torch.linspace(0,4,5)
print(t)
结果为
tensor([0., 1., 2., 3., 4.])
简单来说就是生成一个由等差数列构成的张量。
repeat()函数就是重复的意思,里面可以有许多参数。比如repeat(x,y,z)具体来说就是将原本的张量的行数变为原本的y倍,其列数变为原本的z倍,然后再其深度上变为原本的x倍。当然里面的参数可以远不止3个。简单来说就是在其维度上变为原来的几倍。再放一个例子->
a = torch.linspace(0,4,5).repeat(2,1)
print(a)
print(a.shape)
结果为
tensor([[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]])
torch.Size([2, 5])
--------------------------------------
b = torch.linspace(0,4,5).repeat(3,2,1)
print(b)
print(b.shape)
结果为
tensor([[[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]],
[[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]],
[[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]]])
torch.Size([3, 2, 5])
repeat()函数后面还可以跟上一些函数。比如t(),可以将张量转置,不过这个张量不能超过二维。三维的就不行了。