libtorch中cat函数的使用

 在使用libtorch中经常用到vector和cat使用的情况,在此写了几个例子;cat函数一般有三种使用方式,分别如下:

    auto tensor1 = torch::randn({1, 3, 4, 4});
    auto tensor2 = torch::randn({1, 3, 4, 4});

    //method 1
    auto cattensors = torch::cat({tensor1, tensor2});
    cout << cattensors.sizes() << endl;
    //method 2
    vector tensor_vec;
    tensor_vec.push_back(tensor1);
    tensor_vec.push_back(tensor2);
    torch::TensorList tensorlist{tensor_vec};
    cattensors = torch::cat(tensorlist);
    cout << cattensors.sizes() << endl;
    //method 3
    vector tensor_vec2;
    tensor1 = tensor1.permute({0, 3, 1, 2}).contiguous();
    tensor_vec2.push_back(tensor1);
    tensor2 = tensor2.permute({0, 3, 1, 2}).contiguous();
    tensor_vec2.push_back(tensor2);
    auto cattensors2 = torch::cat(tensor_vec2);
    cout << cattensors2.sizes() << endl;

需要注意的是,cat拼接tensor时必须时连续的tensor

你可能感兴趣的:(C/C++,Pytorch,人工智能,深度学习,计算机视觉)