torch.cat(tensors, dim=0, *, out=None) → Tensor
此方法的意思是对给定的tensor在指定的维度上进行拼接,可以看做是torch.split()
和
torch.chunk()的逆向操作。
参数介绍:
tensors
:要拼接的tensor,这里应该是多个tensor(此处只要是tensor就可以,包括复数tensor和量化tensor),例如如果将tensor a,b进行拼接,这里的输入应该是(a, b)。dim
:从哪个维度上对tensor进行拼接,0为第一个维度,1为第二个维度。out
:输出的tensor注意: 对于进行拼接的两个张量,它们在除了拼接的维度上,其余维度上的形状应该相等。例如tensors=(a, b), dim=0那么a,b除了dim=1的维度,其余维度上的形状大小应该相等,不然会报错,因为无法进行拼接。
使用方法如下:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])