原文链接请参考:PyTorch 常用函数解析 | 梦家博客
对张量沿着某一维度进行拼接。连接后数据的总维数不变。,ps:能拼接的前提是对应的维度相同!!!
例如对两个2维tensor(分别为2*3
,1*3
)进行拼接,拼接完后变为3*3
的2维 tensor。
In [1]: import torch
In [2]: torch.manual_seed(1)
Out[2]: <torch._C.Generator at 0x19e56f02e50>
In [3]: x = torch.randn(2,3)
In [4]: y = torch.randn(1,3)
In [5]: x
Out[5]:
tensor([[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661]])
In [6]: y
Out[6]: tensor([[-1.5228, 0.3817, -1.0276]])
In [9]: torch.cat((x,y),0)
Out[9]:
tensor([[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661],
[-1.5228, 0.3817, -1.0276]])
以上dim=0
表示按列进行拼接,dim=1
表示按行进行拼接。
代码如下:
In [11]: z = torch.randn(2,2)
In [12]: z
Out[12]:
tensor([[-0.5631, -0.8923],
[-0.0583, -0.1955]])
In [13]: x
Out[13]:
tensor([[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661]])
In [14]: torch.cat((x,z),1)
Out[14]:
tensor([[ 0.6614, 0.2669, 0.0617, -0.5631, -0.8923],
[ 0.6213, -0.4519, -0.1661, -0.0583, -0.1955]])
torch.cat()
拼接不会增加新的维度,但torch.stack()
则会增加新的维度。
例如对两个1*2
维的 tensor 在第0个维度上stack,则会变为2*1*2
的 tensor;在第1个维度上stack,则会变为1*2*2
的tensor。
In [22]: x = torch.randn(1,2)
In [23]: y = torch.randn(1,2)
In [24]: x.shape
Out[24]: torch.Size([1, 2])
In [25]: x = torch.randn(1,2)
In [26]: y = torch.randn(1,2)
In [27]: torch.stack((x,y),0) # 维度0堆叠
Out[27]:
tensor([[[-1.8313, 1.5987]],
[[-1.2770, 0.3255]]])
In [28]: torch.stack((x,y),0).shape
Out[28]: torch.Size([2, 1, 2])
In [29]: torch.stack((x,y),1) # 维度1堆叠
Out[29]:
tensor([[[-1.8313, 1.5987],
[-1.2770, 0.3255]]])
In [30]: torch.stack((x,y),1).shape
Out[30]: torch.Size([1, 2, 2])
举例说明
torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
原来x的结果:
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]
将x的维度互换:x.transpose(0,1)
,其实相当于转置操作!
结果
0.6614 0.6213
0.2669 -0.4519
0.0617 -0.1661
[torch.FloatTensor of size 3x2]
permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。
In [31]: x = torch.randn(2,3,4)
In [32]: x
Out[32]:
tensor([[[ 0.7626, 0.4415, 1.1651, 2.0154],
[ 0.2152, -0.5242, -1.8034, -1.3083],
[ 0.4100, 0.4085, 0.2579, 1.0950]],
[[-0.5065, 0.0998, -0.6540, 0.7317],
[-1.4567, 1.6089, 0.0938, -1.2597],
[ 0.2546, -0.5020, -1.0412, 0.7323]]])
In [33]: x.shape
Out[33]: torch.Size([2, 3, 4])
In [34]: x.permute(1,0,2) # 0维和1维互换,2维不变!
Out[34]:
tensor([[[ 0.7626, 0.4415, 1.1651, 2.0154],
[-0.5065, 0.0998, -0.6540, 0.7317]],
[[ 0.2152, -0.5242, -1.8034, -1.3083],
[-1.4567, 1.6089, 0.0938, -1.2597]],
[[ 0.4100, 0.4085, 0.2579, 1.0950],
[ 0.2546, -0.5020, -1.0412, 0.7323]]])
In [35]: x.permute(1,0,2).shape
Out[35]: torch.Size([3, 2, 4])
常用来增加或减少维度,如没有batch维度时,增加batch维度为1。
In [38]: x = torch.randn(1,3,4)
In [39]: x.shape
Out[39]: torch.Size([1, 3, 4])
In [40]: x
Out[40]:
tensor([[[-0.4791, 0.2912, -0.8317, -0.5525],
[ 0.6355, -0.3968, -0.6571, -1.6428],
[ 0.9803, -0.0421, -0.8206, 0.3133]]])
In [41]: x.squeeze()
Out[41]:
tensor([[-0.4791, 0.2912, -0.8317, -0.5525],
[ 0.6355, -0.3968, -0.6571, -1.6428],
[ 0.9803, -0.0421, -0.8206, 0.3133]])
In [42]: x.squeeze().shape
Out[42]: torch.Size([3, 4])
In [43]: x.unsqueeze(0)
Out[43]:
tensor([[[[-0.4791, 0.2912, -0.8317, -0.5525],
[ 0.6355, -0.3968, -0.6571, -1.6428],
[ 0.9803, -0.0421, -0.8206, 0.3133]]]])
In [44]: x.unsqueeze(0).shape
Out[44]: torch.Size([1, 1, 3, 4])