将两个张量进行拼接
主要是对dim取值的一个理解
a = torch.rand(2,3)
b = torch.rand(1,3)
c = torch.cat((a,b),dim=0)
当张量为二维的时候,dim=0表示行,dim=1表示列
因此上述代码中行可以不一样,但是列数必须相同,否则的话无法进行拼接。
a=([[0.3956, 0.4206, 0.3445],
[0.1258, 0.7269, 0.4539]])
b= ([[0.6137, 0.1521, 0.1562]])
c= ([[0.3956, 0.4206, 0.3445],
[0.1258, 0.7269, 0.4539],
[0.6137, 0.1521, 0.1562]])
当张量为三的时候,
import torch
a = torch.rand(2,2,3)
print("a=",a)
b = torch.rand(1,2,3)
print("b=",b)
c = torch.cat((a,b),dim=0)
print("c=",c)
a= tensor([[[0.3859, 0.4549, 0.3441],
[0.7041, 0.1164, 0.5377]],
[[0.2150, 0.9613, 0.4888],
[0.9834, 0.8159, 0.9237]]])
b= tensor([[[0.4117, 0.8209, 0.5537],
[0.1129, 0.0271, 0.0679]]])
c= tensor([[[0.3859, 0.4549, 0.3441],
[0.7041, 0.1164, 0.5377]],
[[0.2150, 0.9613, 0.4888],
[0.9834, 0.8159, 0.9237]],
[[0.4117, 0.8209, 0.5537],
[0.1129, 0.0271, 0.0679]]])
其中dim=0,表示的为batch,也就是上述索引rand=0,dim=1表示行,dim=2表示列,因此当dim表示多少的时候,只有当前索引位置的数值可以不同,其他地方的索引值必须相同。