torch.cat()函数

Concatenates the given sequence of tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.seq

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().

torch.cat() can be best understood via examples.

Parameters

  • tensors (sequence of Tensors) – any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.

  • dim (intoptional) – the dimension over which the tensors are concatenated

Keyword Arguments

out (Tensoroptional) – the output tensor.

Example:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])

cat函数总结:

1.若tensor向量x是[C,H,W]三维,则:

1.dim=0,代表从通道上拼接各个tensor

2.dim=1,代表从列上拼接各个tensor

3.dim=2,代表从行上拼接各个tensor

2.而对于普通的tensor向量x[H,W],则:

1.dim=0,代表从列上拼接各个tensor

2.dim=1,代表从行上拼接各个tensor。

如下所示:

torch.cat()函数_第1张图片

tips:按照行或者列拼接的时候,必须相互对应,不能大小不同。否则会报错,如下所示:

torch.cat()函数_第2张图片

 

你可能感兴趣的:(pytorch,python,人工智能)