Converting the RealBasicVSR Model to ONNX and Running Inference in C++

Table of Contents

    • Setting up the RealBasicVSR environment
      • 1. Create a new conda environment
      • 2. Install PyTorch (pick a suitable version from the official site); versions that are too old cause problems
      • 3. Install mim and mmcv-full
      • 4. Install mmedit
    • Download the RealBasicVSR source code
    • Download the model file
    • Write a model conversion script
    • Test the exported model
    • Inference with ONNX Runtime in C++
    • Results

Setting up the RealBasicVSR environment

1. Create a new conda environment

conda create -n RealBasicVSR_to_ONNX  python=3.8 -y
conda activate RealBasicVSR_to_ONNX

2. Install PyTorch (pick a suitable version from the official site); versions that are too old will cause problems

conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge

3. Install mim and mmcv-full

pip install openmim
mim install mmcv-full

4. Install mmedit

pip install mmedit
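
To confirm the environment is wired up correctly before moving on, a quick sanity check like the following can help (a minimal sketch, not part of the original workflow; the versions printed on your machine will differ):

# All three imports must succeed, and torch should see the GPU if the CUDA build was installed.
import torch
import mmcv
import mmedit

print("torch:", torch.__version__, "CUDA available:", torch.cuda.is_available())
print("mmcv:", mmcv.__version__)
print("mmedit:", mmedit.__version__)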

Download the RealBasicVSR source code

git clone https://github.com/ckkelvinchan/RealBasicVSR.git

Download the model file

Download the model weights (Dropbox / Google Drive / OneDrive); any one of the mirrors will do.

cd RealBasicVSR
# create a "model" folder
mkdir model
# put the downloaded checkpoint (RealBasicVSR_x4.pth) into the model folder

Write a model conversion script

import cv2
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmedit.core import tensor2img

from realbasicvsr.models.builder import build_model


def init_model(config, checkpoint=None):
    """Build the RealBasicVSR model from a config file and load the checkpoint."""
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    config.model.pretrained = None
    config.test_cfg.metrics = None
    model = build_model(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)

    model.cfg = config  # save the config in the model for convenience
    model.eval()
    return model


def main():
    model = init_model("./configs/realbasicvsr_x4.py", "./model/RealBasicVSR_x4.pth")

    # Read one test frame and convert it to a (1, 3, H, W) float tensor in [0, 1]
    src = cv2.imread("./data/img/test1.png")
    src = torch.from_numpy(src / 255.).permute(2, 0, 1).float()
    src = src.unsqueeze(0)
    # Stack along dim=1 to build the 5-D video input (N, T, C, H, W) with T=1
    input_arg = torch.stack([src], dim=1)

    # Dims 3 and 4 of the (N, T, C, H, W) input are height and width,
    # so mark them (and the batch dim) as dynamic to allow arbitrary resolutions.
    torch.onnx.export(model,
        input_arg,
        'realbasicvsr.onnx',
        training=True,
        input_names=['input'],
        output_names=['output'],
        opset_version=11,
        dynamic_axes={'input': {0: 'batch_size', 3: 'h', 4: 'w'},
                      'output': {0: 'batch_size', 3: 'dsth', 4: 'dstw'}})


if __name__ == '__main__':
    main()

Running this script throws an error:

ValueError: SRGAN model does not support `forward_train` function.

The simplest workaround is to change the default value of `test_mode` to True so that the export can proceed.
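For reference, the default lives in the restorer's forward() in mmedit (likely mmedit/models/restorers/srgan.py, though the exact file may vary by mmedit version). A sketch of what the edited signature looks like with the default flipped:

# mmedit/models/restorers/srgan.py (path may vary across mmedit versions)
# With test_mode defaulting to False, ONNX export traces forward_train(), which
# raises the ValueError above; defaulting it to True routes export through forward_test().
def forward(self, lq, gt=None, test_mode=True, **kwargs):
    if test_mode:
        return self.forward_test(lq, gt, **kwargs)
    return self.forward_train(lq, gt)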

Test the exported model

We now have the realbasicvsr.onnx model file; the following script loads it and runs a quick test.

import onnxruntime as ort
import numpy as np
import onnx
import cv2


def main():
    # Load the exported graph and create an inference session
    onnx_model = onnx.load_model("./realbasicvsr.onnx")
    onnxstrongmodel = onnx_model.SerializeToString()
    sess = ort.InferenceSession(onnxstrongmodel)

    # Prefer CUDA when onnxruntime-gpu is installed, otherwise fall back to CPU
    providers = ['CPUExecutionProvider']
    options = [{}]
    is_cuda_available = ort.get_device() == 'GPU'
    if is_cuda_available:
        providers.insert(0, 'CUDAExecutionProvider')
        options.insert(0, {'device_id': 0})
    sess.set_providers(providers, options)

    input_name = sess.get_inputs()[0].name
    # The exported graph has two outputs: index 0 stays at the input size,
    # index 1 is the 4x super-resolved frame.
    output_name = sess.get_outputs()[1].name
    print(sess.get_inputs()[0])
    print(sess.get_outputs()[0])
    print(sess.get_outputs()[0].shape)
    print(sess.get_inputs()[0].shape)

    # Build the (1, 1, 3, H, W) float input in [0, 1]
    img = cv2.imread("./data/img/test1.png")
    img = np.expand_dims((img / 255.0).astype(np.float32).transpose(2, 0, 1), axis=0)
    imgs = np.array([img])
    print(imgs.shape)

    output = sess.run([output_name], {input_name: imgs})
    print(output[0].shape)

    # Clip to [0, 1], drop the batch/time dims and convert back to HWC uint8
    output = np.clip(output, 0, 1)
    res = output[0][0][0].transpose(1, 2, 0)
    cv2.imwrite("./testout.png", (res * 255).astype(np.uint8))


if __name__ == '__main__':
    main()

At this point the model conversion part is complete.
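
Before moving on to the C++ side, it can be worth running the ONNX checker over the exported file; a minimal sketch (the output names depend on how the export above assigned them):

import onnx

# Structural validation of the exported graph; raises if the model is malformed.
model = onnx.load("./realbasicvsr.onnx")
onnx.checker.check_model(model)
# List the graph outputs; there should be two (original-size frame and the 4x result).
print([o.name for o in model.graph.output])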

Inference with ONNX Runtime in C++

Pick an onnxruntime release that matches your CUDA version.

Download the ONNX Runtime prebuilt package (onnxruntime); since the code below uses the CUDA execution provider, you need the GPU build that matches your CUDA version.

#include <iostream>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>

class ONNX_RealBasicVSR{
  public:
    ONNX_RealBasicVSR():session(nullptr){};
    virtual ~ONNX_RealBasicVSR() = default;
    /** Initialize the session
     * @param model_path path to the ONNX model
     * @param gpu_id     which GPU to run on
     */
    void Init(const char * model_path, int gpu_id = 0);

    /** Run model inference
     * @param src      input image
     * @param inputid  index of the input node
     * @param outputid index of the output node
     * @return the result image
     */
    cv::Mat Run(cv::Mat src, unsigned inputid = 0, unsigned outputid = 0, bool show_log = false);

  private:
    /** Get the name of an input or output node
     * @param input_or_output  whether to query an input or an output
     * @param id               index of the node
     * @param show_log         whether to print debug info
     * @return the node name
     */
    std::string GetInputOrOutputName(std::string input_or_output = "input", unsigned id = 0, bool show_log = false);

    /** Get the shape of an input or output node
     * @param input_or_output  whether to query an input or an output
     * @param id               index of the node
     * @param show_log         whether to print debug info
     * @return the node shape
     */
    std::vector<int64_t> GetInputOrOutputShape(std::string input_or_output = "input", unsigned id = 0, bool show_log = false);


    mutable Ort::Session session;

    Ort::Env env; //(ORT_LOGGING_LEVEL_VERBOSE, "test");

};


void ONNX_RealBasicVSR::Init(const char * model_path, int gpu_id){
    // env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "ONNX_RealBasicVSR");
    env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "ONNX_RealBasicVSR");
    Ort::SessionOptions session_options;
    // use five intra-op threads to speed things up
    session_options.SetIntraOpNumThreads(5);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    // run on the CUDA execution provider, on the requested GPU
    OrtCUDAProviderOptions cuda_option;
    cuda_option.device_id = gpu_id;
    session_options.AppendExecutionProvider_CUDA(cuda_option);
    session = Ort::Session(env, model_path, session_options);

    return;
}

cv::Mat ONNX_RealBasicVSR::Run(cv::Mat src, unsigned inputid, unsigned outputid, bool show_log){
    int64_t H = src.rows;
    int64_t W = src.cols;
    // convert the image to a normalized NCHW float blob in [0, 1]
    cv::Mat blob;
    cv::dnn::blobFromImage(src, blob, 1.0 / 255.0, cv::Size(W, H), cv::Scalar(0, 0, 0), false, true);

    // create the input tensor
    size_t input_tensor_size = blob.total();
    std::vector<float> input_tensor_values(input_tensor_size);

    // overwrite the dynamic input dims; the layout is (N, T, C, H, W)
    std::vector<int64_t> input_node_dims = GetInputOrOutputShape("input", inputid, show_log);
    input_node_dims[0] = 1;
    input_node_dims[3] = H;
    input_node_dims[4] = W;

    for (size_t i = 0; i < input_tensor_size; ++i)
    {
        input_tensor_values[i] = blob.at<float>(i);
    }

    // print the input shape
    if(show_log){
        std::cout << "shape:";
        for(auto &i : input_node_dims){
            std::cout << " " << i;
        }
        std::cout << std::endl;
        std::cout << "input_tensor_size: " << input_tensor_size << std::endl;
    }

    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    auto input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), input_node_dims.size());

    std::string input_name = GetInputOrOutputName("input", inputid, show_log);
    std::string output_name = GetInputOrOutputName("output", outputid, show_log);
    const char* inputname[] = {input_name.c_str()};   // input node name
    const char* outputname[] = {output_name.c_str()}; // output node name

    std::vector<Ort::Value> output_tensor = session.Run(Ort::RunOptions{nullptr}, inputname, &input_tensor, 1, outputname, 1);

    if(show_log){
        // how many outputs were returned
        std::cout << "output_tensor_size: " << output_tensor.size() << std::endl;
    }

    // query the output shape
    Ort::TensorTypeAndShapeInfo shape_info = output_tensor[0].GetTensorTypeAndShapeInfo();

    // number of dimensions of the output
    size_t dim_count = shape_info.GetDimensionsCount();
    if(show_log){
        std::cout << dim_count << std::endl;
    }

    auto shape = shape_info.GetShape();
    if(show_log){
        // print the output shape
        std::cout << "shape: ";
        for(auto &i : shape){
            std::cout << i << " ";
        }
        std::cout << std::endl;
    }

    // grab the output data; layout is (N, T, C, H, W)
    float* f = output_tensor[0].GetTensorMutableData<float>();

    int output_height = shape[3];
    int output_width = shape[4];
    int size_pic = output_width * output_height;
    cv::Mat fin_img;
    std::vector<cv::Mat> rgbChannels(3);
    rgbChannels[0] = cv::Mat(output_height, output_width, CV_32FC1, f);
    rgbChannels[1] = cv::Mat(output_height, output_width, CV_32FC1, f + size_pic);
    rgbChannels[2] = cv::Mat(output_height, output_width, CV_32FC1, f + size_pic + size_pic);
    merge(rgbChannels, fin_img);
    // scale back to [0, 255] and saturate-cast to 8-bit so cv::imwrite can save it
    // (the same clip + uint8 cast done in the Python test script)
    fin_img = fin_img * 255;
    fin_img.convertTo(fin_img, CV_8UC3);
    return fin_img;
}

std::string ONNX_RealBasicVSR::GetInputOrOutputName(std::string input_or_output, unsigned id, bool show_log){
    size_t num_input_nodes = session.GetInputCount();
    size_t num_output_nodes = session.GetOutputCount();

    if(show_log){
        // how many inputs and outputs the model has
        std::cout << "num_input_nodes:" << num_input_nodes << std::endl;
        std::cout << "num_output_nodes:" << num_output_nodes << std::endl;
    }

    Ort::AllocatorWithDefaultOptions allocator;
    std::string name;
    if(input_or_output == "input"){
        Ort::AllocatedStringPtr input_name_Ptr = session.GetInputNameAllocated(id, allocator);
        name = input_name_Ptr.get();
    }else{
        auto output_name_Ptr = session.GetOutputNameAllocated(id, allocator);
        name = output_name_Ptr.get();
    }

    if(show_log){
        std::cout << "name:" << name << std::endl;
    }

    return name;
}

std::vector<int64_t> ONNX_RealBasicVSR::GetInputOrOutputShape(std::string input_or_output, unsigned id, bool show_log){
    std::vector<int64_t> shape;
    if(input_or_output == "input"){
        Ort::TypeInfo type_info = session.GetInputTypeInfo(id);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
        // element type of the input node
        ONNXTensorElementDataType type = tensor_info.GetElementType();
        if(show_log){
            std::cout << "input type: " << type << std::endl;
        }
        shape = tensor_info.GetShape();
        if(show_log){
            std::cout << "input shape:";
            for(auto &i : shape){
                std::cout << " " << i;
            }
            std::cout << std::endl;
        }
    }else{
        Ort::TypeInfo type_info_out = session.GetOutputTypeInfo(id);
        auto tensor_info_out = type_info_out.GetTensorTypeAndShapeInfo();
        // element type of the output node
        ONNXTensorElementDataType type_out = tensor_info_out.GetElementType();
        if(show_log){
            std::cout << "output type: " << type_out << std::endl;
        }

        // shape of the output node (dynamic dims are reported as -1)
        shape = tensor_info_out.GetShape();
        if(show_log){
            std::cout << "output shape:";
            for(auto &i : shape){
                std::cout << " " << i;
            }
            std::cout << std::endl;
        }
    }
    return shape;
}

#include "ONNX/liangbaikai_RealBasicVSR_onnx.hpp"
int main(){
    ONNX_RealBasicVSR orbv;
    orbv.Init("../realbasicvsr.onnx");
    cv::Mat img = cv::imread("../img/test1.png");
	unsigned inputid = 0;
	unsigned outputid = 1;
	int W = img.cols, H = img.rows;

    if(W > H){
        cv::copyMakeBorder(img,img,0,W - H,0,0,cv::BORDER_REFLECT_101);
     }else if(H > W){
         cv::copyMakeBorder(img,img,0,0,0,H-W,cv::BORDER_REFLECT_101);
     }

    cv::Mat res = orbv.Run(img,inputid,outputid,true);

     if(outputid == 1){
         res = res(cv::Rect(0,0,W * 4,H * 4));
     }else{
         res = res(cv::Rect(0,0,W,H));
     }
    cv::imwrite("./tttttttfin.png",res);
    return 0;
}

Results:

(Before/after screenshots of the super-resolved output.)
