Pytorch 中 torch.cat() 函数解析

Pytorch 中主要有两个拼接函数:

  1. torch.stack()

  1. torch.cat()

这里主要介绍 torch.cat()

1 函数作用

  • 对给定维度上的输入的 Tensor 序列进行连接

  • torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor

2 参数解析

import torch

outputs = torch.cat(inputs, dim) -> Tensor
  • inputs : 待连接的张量, 必须是 Tensor, 注意连接多个 Tensor 时, 需要将多个 Tensor 放入一个 list[] 中

  • dim : 从那个维度进行连接, 必须小于维度的个数

3 示例

import torch

a = torch.tensor([[1, 1, 1], [2, 2, 2]])
b = torch.tensor([[3, 3, 3], [4, 4, 4]])
print("dim = 0 :", torch.cat([a, b], dim = 0))
print("dim = 1 :", torch.cat([a, b], dim = 1))

>>> dim = 0 : tensor([[1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3],
                    [4, 4, 4]])
>>> dim = 1 : tensor([[1, 1, 1, 3, 3, 3],
                    [2, 2, 2, 4, 4, 4]])

你可能感兴趣的:(Pytorch,中的各种函数,Pytorch)