ptorch.cat理解

torch.cat

将两个张量进行拼接
主要是对dim取值的一个理解

a = torch.rand(2,3)
b = torch.rand(1,3)
c = torch.cat((a,b),dim=0)

当张量为二维的时候,dim=0表示,dim=1表示
因此上述代码中行可以不一样,但是列数必须相同,否则的话无法进行拼接。
a=([[0.3956, 0.4206, 0.3445],
[0.1258, 0.7269, 0.4539]])

b= ([[0.6137, 0.1521, 0.1562]])

c= ([[0.3956, 0.4206, 0.3445],
[0.1258, 0.7269, 0.4539],
[0.6137, 0.1521, 0.1562]])

当张量为三的时候,

import torch
a = torch.rand(2,2,3)
print("a=",a)
b = torch.rand(1,2,3)
print("b=",b)
c = torch.cat((a,b),dim=0)
print("c=",c)

a= tensor([[[0.3859, 0.4549, 0.3441],
[0.7041, 0.1164, 0.5377]],

[[0.2150, 0.9613, 0.4888],
[0.9834, 0.8159, 0.9237]]])

b= tensor([[[0.4117, 0.8209, 0.5537],
[0.1129, 0.0271, 0.0679]]])
c= tensor([[[0.3859, 0.4549, 0.3441],
[0.7041, 0.1164, 0.5377]],

[[0.2150, 0.9613, 0.4888],
[0.9834, 0.8159, 0.9237]],

[[0.4117, 0.8209, 0.5537],
[0.1129, 0.0271, 0.0679]]])

其中dim=0,表示的为batch,也就是上述索引rand=0,dim=1表示行,dim=2表示列,因此当dim表示多少的时候,只有当前索引位置的数值可以不同,其他地方的索引值必须相同。

你可能感兴趣的:(深度学习,pytorch)