pytorch中torch.cat() 和paddle中的paddle.concat()函数用法

在pytorch中:

torch.cat(x=[A,B],dim=n)函数:是将数据A,B沿着dim=n的方向进行拼接;x必须是list或者tuple类型的tensor.

例子:

import torch
x1 = torch.tensor([[1, 2, 3],
                    [4, 5, 6]])

x2 = torch.tensor([[11, 12, 13],
                    [14, 15, 16]])

x3 = torch.tensor([[21, 22],
                    [23, 24]])

out1=torch.cat([x1,x2,x3],dim=-1) #(dim=-1:横着拼接)
out2=torch.cat([x1,x2,x3],dim=1) #(dim=1:横着拼接)
out3=torch.cat([x1,x2],dim=0) #(dim=0:竖着拼接)


print(out1.shape)
print("out1:\n",out1)
print("out2:\n",out2)
print("out3:\n",out3)


========================================
输出结果
torch.Size([2, 8])
out1:  #(dim=-1:横着拼接)
 tensor([[ 1,  2,  3, 11, 12, 13, 21, 22],
        [ 4,  5,  6, 14, 15, 16, 23, 24]])
out2:  #(dim=1:横着拼接)
 tensor([[ 1,  2,  3, 11, 12, 13, 21, 22],
        [ 4,  5,  6, 14, 15, 16, 23, 24]])
out3:  #(dim=0:竖着拼接)
 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [11, 12, 13],
        [14, 15, 16]])

注意:竖着拼接的两个数据的维度需要相同;dim=-1 和dim=1的效果是相同的。

在paddle中:

paddle.concat(xaxis=0name=None):对输入沿参数 axis 轴进行拼接,返回一个新的 Tensor。

参数:x一般是list或者tuple类型的数据

  • x (list|tuple) - 待联结的 Tensor list 或者 Tensor tuple支持的数据类型为:bool、float16、float32、float64、int32、int64、uint8, x 中所有 Tensor 的数据类型应该一致。

  • axis (int|Tensor,可选) - 指定对输入 x 进行运算的轴,可以是整数或者形状为[1]的 Tensor,数据类型为 int32 或者 int64。 axis 的有效范围是 [-R, R),R 是输入 x 中 Tensor 的维度,axis 为负值时与 axis+R 等价。默认值为 0。

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

例子:

import paddle

x1 = paddle.to_tensor([[1, 2, 3],
                       [4, 5, 6]])
x2 = paddle.to_tensor([[11, 12, 13],
                       [14, 15, 16]])
x3 = paddle.to_tensor([[21, 22],
                       [23, 24]])
zero = paddle.full(shape=[1], dtype='int32', fill_value=0)


out1 = paddle.concat(x=[x1, x2, x3], axis=-1) #(-1:横着拼接) -1和 1都是横着拼接
out2 = paddle.concat(x=[x1, x2], axis=0) # (0 :竖着拼接:)
out3 = paddle.concat(x=[x1, x2], axis=zero)

# out1
# [[ 1  2  3 11 12 13 21 22]    #(-1:横着拼接)
#  [ 4  5  6 14 15 16 23 24]]   
# out2 out3     # (0 :竖着拼接:)
# [[ 1  2  3]
#  [ 4  5  6]
#  [11 12 13]
#  [14 15 16]]

你可能感兴趣的:(深度学习,pytorch,paddle,深度学习)