ONNX模型转TRT部署推理c++

训练好的模型(如.pt)转成onnx形式,ONNX定义了一组与环境和平台无关的标准格式。ONNX文件不仅存储了神经网络模型的权重,还存储了模型的结构信息、网络中各层的输入输出等一些信息。

ONNX的推理可以用ONNX Runtime官方库,如果在英伟达平台上,可以转TensorRT后运行。本文主要介绍转TRT格式后如何C++部署运行。

1、ONNX 转 TRT

基本流程就是:
1、创建构建器,由构建器创建网络,然后解析器解析ONNX文件。
2、设置一些必要参数
3、构建器构建网络(生成推理引擎),然后序列化保存成TRT模型
用到的logging.h 文件直接用NVIDIA自带的。

#include "NvInfer.h"
#include "NvOnnxParser.h"
#include "logging.h"
#include <cassert>
#include <fstream>
#include <iostream>
// File-scope using-directives: the listing refers to IBuilder, IParser,
// BuilderFlag, etc. without namespace qualification.
using namespace std;
using namespace nvonnxparser;
using namespace nvinfer1;

// Precision switch: onnx2trt() tests USE_FP16 / USE_INT8 with #ifdef to decide
// which BuilderFlag to set on the builder config.
#define USE_FP16
// Logger instance handed to createInferBuilder() and createParser().
static Logger gLogger;

/**
 * @brief Write a serialized TensorRT engine to a binary file.
 *
 * @param trt_save_path   Destination path of the engine file.
 * @param trtModelStream  Serialized engine blob (e.g. the result of
 *                        ICudaEngine::serialize()). Not owned by this function;
 *                        the caller is still responsible for destroying it.
 *
 * The function returns void, so failures are only reported on stdout; on any
 * failure nothing further is written.
 */
void saveToTrtModel(std::string trt_save_path,IHostMemory*trtModelStream){
    if (trtModelStream == nullptr) {
        std::cout << "serialized engine is null, nothing to save" << std::endl;
        return;
    }

    std::ofstream out(trt_save_path, std::ios::binary);
    if (!out.is_open()){
        std::cout << "打开文件失败!" <<std:: endl;
        // Stop here instead of "writing" to a stream that never opened.
        return;
    }

    out.write(static_cast<const char*>(trtModelStream->data()),
              static_cast<std::streamsize>(trtModelStream->size()));
    if (!out) {
        // E.g. disk full: the file on disk is incomplete and must not be used.
        std::cout << "failed to write engine to " << trt_save_path << std::endl;
    }
    out.close();
}

int onnx2trt(){
    std::string onnx_path = "../onnx_model/plate_detect.onnx";
    std::string trt_save_path = "../onnx_model/plate_detect.trt";
    int batch_size = 1;
    IBuilder * builder = createInferBuilder(gLogger);
    INetworkDefinition *network = builder->createNetworkV2(1U);
    // 解析模型
    IParser *parser = nvonnxparser::createParser(*network, gLogger);
    if(!parser->parseFromFile(onnx_path.c_str(), (int)nvinfer1::ILogger::Severity::kWARNING)){
        std::cout << " parse onnx file fail ..." << std::endl;
        return -1;
    }
    IBuilderConfig *config = builder->createBuilderConfig();
    builder->setMaxBatchSize(batch_size);
    config->setMaxWorkspaceSize(1<<30);
    auto profile = builder->createOptimizationProfile();
    auto input_tensor=network->getInput(0);
    auto input_dims = input_tensor->getDimensions();
    input_dims.d[0] = 1;
     profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMIN, input_dims);
    profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kOPT, input_dims);
    input_dims.d[0] = batch_size;
    profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMAX, input_dims);
    config->addOptimizationProfile(profile);

#ifdef USE_FP16
    config->setFlag(BuilderFlag::kFP16);
#endif
#ifdef USE_INT8
    config->setFlag(BuilderFlag::kINT8);
#endif
    ICudaEngine *engine = builder->buildEngineWithConfig(*network, *config);
    assert(engine);
    IHostMemory* trtModelStream = engine->serialize(); //序列化 保存trt
    saveToTrtModel(trt_save_path.c_str(), trtModelStream);
    parser->destroy();
    engine->destroy();
    network->destroy();
    config->destroy();
    builder->destroy();

    return 0;
}
# CMakeLists.txt for the ONNX -> TensorRT converter (onnx2trt.cpp).
cmake_minimum_required(VERSION 3.10)
project(onnx2trt)

# CUDA: the target links cudart, so the CUDA header and library directories
# must be on the search paths. Adjust to the local CUDA installation.
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)

# TensorRT: adjust to the local TensorRT installation.
include_directories(/home/max/TensorRT-7.2.1.6/include)
link_directories(/home/max/TensorRT-7.2.1.6/lib)

add_executable(onnx2trt onnx2trt.cpp)
target_link_libraries(onnx2trt nvinfer)
target_link_libraries(onnx2trt cudart)
target_link_libraries(onnx2trt nvonnxparser)

2、 TRT 模型推理

转好的TRT模型在部署工程上推理运行。
基本流程:
1、从文件中读取TRT模型,并反序列化为CUDA推理引擎。
2、分配推理所需要的CPU、GPU内存空间
3、引擎推理获取结果
调用以下类接口即可推理。

#include "logging.h"
#include "NvInfer.h"
#include "NvOnnxParser.h"
#include <cassert>
#include <cuda_runtime_api.h> // cuda include
#include <fstream>
#include <iostream>
#include <string>
// File-scope using-directive: IRuntime, ICudaEngine and IExecutionContext are
// used below without the nvinfer1:: qualifier.
using namespace nvinfer1; 
// Logger instance handed to createInferRuntime().
static Logger gLogger;

class TrtDetct
{
private:
    char *_trtModelStream{nullptr};
    IRuntime* _runtime = nullptr;
    ICudaEngine* _engine=nullptr;
    IExecutionContext* _context=nullptr;
    void *_inferbuffers[2];
    int _outputSize = 0;
    int _input_h = 640;
    int _input_w = 640;
    cudaStream_t _stream;

private:
 int getoutputSize(){
    auto out_dims = _engine->getBindingDimensions(1);
    int outputSize = 1;
    for(int j = 0; j < out_dims.nbDims; j++) {
        std::cout << "j = " << j << " size = " << out_dims.d[j] << std::endl;
        outputSize *= out_dims.d[j];
    }
    return outputSize;
}
public:
    TrtDetct(/* args */){};
    ~TrtDetct(){
        if (nullptr != _trtModelStream){
            delete [] _trtModelStream;
        }
    };
    // 文件读取模型,并反序列化成engine
    void load_trtmodel(std::string trt_model_path){
        std::ifstream file(trt_model_path, std::ios::binary);
        size_t size{0};
        if (file.good()) {
                file.seekg(0, file.end);
                size = file.tellg();
                file.seekg(0, file.beg);
                _trtModelStream = new char[size];
                assert(_trtModelStream);
                file.read(_trtModelStream, size);
                file.close();
        }
    _runtime = createInferRuntime(gLogger);
    assert(_runtime != nullptr);
    _engine = _runtime->deserializeCudaEngine(_trtModelStream, size);
    assert(_engine != nullptr); 
    _context = _engine->createExecutionContext();
    assert(_context != nullptr);
    initbuff();
    }

    //分配处理相关内存
    void initbuff(){
        _outputSize = getoutputSize();
        // 这两个值在生成onnx时刻已经固定
        const int inputIndex = _engine->getBindingIndex("input");
        const int outputIndex = _engine->getBindingIndex("output");
        assert(inputIndex == 0);
        assert(outputIndex == 1);
        CHECK(cudaMalloc((void**)&_inferbuffers[inputIndex],  3 * _input_h * _input_w * sizeof(float)));  //trt输入内存申请
        CHECK(cudaMalloc((void**)&_inferbuffers[outputIndex], _outputSize * sizeof(float)));           //trt输出内存申请
        CHECK(cudaStreamCreate(&_stream));
    }
    // 推理
    void infer_trtmodel(){
        //图像数据填充_inferbuffers[0],GPU CUDA处理
        _context->enqueueV2((void **)_inferbuffers, _stream, nullptr);
        //_inferbuffers[1]模型输出后处理,可以GPU处理,否则拷贝到cpu处理
    }
};

CMakeLists.txt 文件如下:

# CMakeLists.txt for the TensorRT inference demo (trt_detect.cpp).
cmake_minimum_required(VERSION 3.10)
project(trt_detect)

find_package(CUDA REQUIRED)

# C++11 is requested through CMAKE_CXX_STANDARD only; passing -std=c++11 via
# add_definitions() as well was redundant.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Default to Release, but let the user override it with -DCMAKE_BUILD_TYPE=...
# Note: Release defines NDEBUG, which disables every assert() in the sources.
if(NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE Release)
endif()

# Compiler flags belong in add_compile_options(); add_definitions() is meant for
# preprocessor definitions. They must be set before the target is created.
# -w suppresses all compiler warnings.
add_compile_options(-w -O2)

# CUDA: adjust to the local CUDA installation.
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
# TensorRT: adjust to the local TensorRT installation.
include_directories(/home/max/TensorRT-7.2.1.6/include)
link_directories(/home/max/TensorRT-7.2.1.6/lib)

cuda_add_executable(trt_detect trt_detect.cpp)

target_link_libraries(trt_detect nvinfer)
target_link_libraries(trt_detect cudart)
target_link_libraries(trt_detect nvonnxparser)

你可能感兴趣的:(c++,开发语言,深度学习,人工智能)