TensorRT 入门(3) 官方样例 sampleOnnxMNIST

文章目录

    • 0. 前言
    • 1. ONNX 模型转换
      • 1.1. build 函数详解

0. 前言

  • 本文提到的sampleMNISTAPI与之前0. 前言

    • 本文提到的sampleMNISTAPI与之前笔记1和笔记2提到的样例有完全相同的输入与输出,不同之处在于模型创建方式不一样。
      • sampleMNIST通过导入一个caffe模型并将caffe模型转换为tensorrt的形式。
      • sampleMNISTAPI通过TensorRT的C++接口直接一层一层搭建模型,并将caffe中的权重导入创建好的网络中。
      • sampleOnnxMNIST通过ONNX构建模型。
  • 一点点疑问:TensorRT要使用ONNX模型应该有两种方式,一种是像本例一样,直接在程序中转换ONNX模型形式,另外还有一种是通过官方工具先将ONNX模型转换为engine文件,不知道这两种方式有什么区别。

1. ONNX 模型转换

  • 其他代码都不说了,就仔细看看 SampleOnnxMNIST::build() 函数。

1.1. build 函数详解

  • 构建网络的流程基本上是
    • 构建builder
    • 构建空白network对象
    • 构建buildConfig参数
    • 构建Onnx模型解析器
    • 通过解析器将模型结构保存在network对象中
    • 设置一些模型参数(比如模型量化)
    • 验证结果
bool SampleOnnxMNIST::build()
{
    // 构建模型builder
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }

    // 构建空白network对象
    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
    if (!network)
    {
        return false;
    }

    // 创建BuildConfig,我也不知道是干啥用的
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }

    // 构建Onnx模型解析器
    auto parser
        = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
    if (!parser)
    {
        return false;
    }

    // 构建模型,通过parser解析,并将解析结果导入network中
    auto constructed = constructNetwork(builder, network, config, parser);
    if (!constructed)
    {
        return false;
    }

    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }

    // 验证结果
    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 4);
    assert(network->getNbOutputs() == 1);
    mOutputDims = network->getOutput(0)->getDimensions();
    assert(mOutputDims.nbDims == 2);

    return true;
}
  • 前一步的核心就是constrctNetwork,即通过parser解析模型并保存到network中
//!
//! \brief Uses a ONNX parser to create the Onnx MNIST Network and marks the
//!        output layers
//!
//! \param network Pointer to the network that will be populated with the Onnx MNIST network
//!
//! \param builder Pointer to the engine builder
//!
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    // 注意,构建解析器的时候就已经把network对象作为参数传入了
    auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
        static_cast<int>(sample::gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }

    // 模型量化,不知道跟onnx_tensorrt工具有啥区别
    config->setMaxWorkspaceSize(16_MiB);
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }

    // 这里的 DLA 就是 Deep Learning Accelerator
    // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dla_layers
    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    return true;
}

你可能感兴趣的:(TensorRT,C++,TensorRT,模型部署)