pytorch每日一学29(torch.cat())在给定的维度上对tensor进行拼接

第29个方法

torch.cat(tensors, dim=0, *, out=None) → Tensor

此方法的意思是对给定的tensor在指定的维度上进行拼接,可以看做是torch.split()
torch.chunk()的逆向操作。

参数介绍:

  • tensors:要拼接的tensor,这里应该是多个tensor(此处只要是tensor就可以,包括复数tensor和量化tensor),例如如果将tensor a,b进行拼接,这里的输入应该是(a, b)。
  • dim:从哪个维度上对tensor进行拼接,0为第一个维度,1为第二个维度。
  • out:输出的tensor

注意: 对于进行拼接的两个张量,它们在除了拼接的维度上,其余维度上的形状应该相等。例如tensors=(a, b), dim=0那么a,b除了dim=1的维度,其余维度上的形状大小应该相等,不然会报错,因为无法进行拼接。

使用方法如下:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])

当然此方法对复数tensor也适用:
pytorch每日一学29(torch.cat())在给定的维度上对tensor进行拼接_第1张图片
指定dim形状不一样时照样可以拼接:
pytorch每日一学29(torch.cat())在给定的维度上对tensor进行拼接_第2张图片
如果非指定维度形状不等就会报错:
pytorch每日一学29(torch.cat())在给定的维度上对tensor进行拼接_第3张图片

你可能感兴趣的:(pytorch每日一学,深度学习,pytorch,神经网络,机器学习,数据挖掘)