libtorch之cv::Mat和Tensor的互转

      cv::Mat转Tensro libtorch官网例子中给的是如下形式:

        torch::TensorOptions option(torch::kFloat);   

        auto img_tensor = torch::from_blob(img.data,

                                                             {1,img.rows,img.cols,img.channels() }, option);

                                                            // opencv (H x W x C)  to torch  (batch x H x W x C)

        img_tensor = img_tensor.permute({0, 3, 1, 2 });//调整矩阵列顺序 batch x C x H x W

     这种方式应用在分类问题中应该没什么问题,但应用在faster RCNN时却存在一定问题。使用from_blob构造一个Tensor,他只是一个内存数据的拷贝,没有更多的对数据处理。

      cv::Mat数据在内存中是如下形式存储的

     (r,g,b),(r,g,b),(r,g,b)......   

     然而Tensor在内存中的数据确是

      (r,r,r,r,......),(g,g,g,g,......),(b,b,b,......)

     这两种不同结构的数据是无法直接使用内存拷贝方式转换。

     以下是一个完整的cv::Mat转Tensor

    torch::Tensor image2Tensor(cv::Mat &img)

   {

          if(img.channels() == 1)

         {

             torch::TensorOptions option(torch::kFloat);

             auto img_tensor = torch::from_blob(img.data, {1,img.rows,img.cols}, option);

                                                            // opencv (H x W)  to torch  (C x H x W)

              return img_tensor;

        }

       else

       {

                torch::Tensor t = torch::zeros({3,img.rows,img.cols},torch::kFloat);

               float * r = (float*)t[0].data_ptr();

               float * g = (float*)t[1].data_ptr();

               float * b = (float*)t[2].data_ptr();

              for(int row = 0; row < img.rows; row++)

             {

                     for(int col = 0; col < img.cols; col++)

                    {

                            cv::Vec3f &v = img.at(row,col);

                            r[row*img.cols + col] = v[0];

                            g[row*img.cols + col] = v[1];

                            b[row*img.cols + col] = v[2];

                   }

           }

            return t;

        }

         return torch::Tensor();

    }

     这是一个Tensor转cv::Mat的函数,方便可视化的呈现特征图像。

     cv::Mat tensor2Image(const torch::Tensor &t)

    {

             torch::IntArrayRef s = t.sizes();

             if(s[0] == 1)

            {

                    cv::Mat des(s[1], s[2], CV_32FC1, t.data_ptr());

                      return des;

             }

            else

           {

                 float * r = (float*)t[0].data_ptr();

                 float * g = (float*)t[1].data_ptr();

                 float * b = (float*)t[2].data_ptr();

                 int rows = s[1];

                 int cols = s[2];

                 cv::Mat des = cv::Mat::zeros(rows, cols,CV_32FC3);

                for(int row=0; row

                {

                        for(int col=0; col

                       {

                            cv::Vec3f v;

                            v[0] = r[row*cols + col];

                            v[1] = g[row*cols + col];

                            v[2] = b[row*cols + col];

                           des.at(row,col) = v;

                       }

                }

               return des;

          }

          return cv::Mat();

      }

你可能感兴趣的:(深度学习,c++)