TensorRT 入门(2) 官方样例 sampleMNISTAPI

文章目录

    • 0. 前言
    • 1. TensoRT C++ 构建网络
      • 1.1. 对象定义
      • 1.2. 建立网络
    • 2. 其他相关代码
      • 2.1. 读取命令行参数
      • 2.2. logger 相关

0. 前言

  • 本文提到的sampleMNISTAPI与之前笔记提到的sampleMNIST有完全相同的输入与输出,不同之处在于模型创建方式不一样。
    • sampleMNIST通过导入一个caffe模型并将caffe模型转换为tensorrt的形式。
    • sampleMNISTAPI通过TensorRT的C++接口直接一层一层搭建模型,并将caffe中的权重导入创建好的网络中。
  • 使用C++构建网络其实有对应的官方文档,建议先看一遍。
  • 本文内容主要包括:
    • 如何通过 TensorRT C++ API 构建网络
    • 代码中其他相关内容

1. TensoRT C++ 构建网络

  • TensorRT相关内容都封装在 SampleMNISTAPI sample... 这个sample对象中。
  • 主要内容包括:
    • 对象定义。
    • 建立网络sample.build()
    • 模型推理sample.infer():之前介绍过(虽然还没懂),这里不多说了
    • 模型关闭sample.teardown():直接介绍过,这里不多说了

1.1. 对象定义

  • 入口就是 SampleMNISTAPI sample(initializeSampleParams(args));

  • 包括两步:

    • 第一步:使用命令行参数构建SampleMNISTAPI的初始化参数。
    • 第二步:初始化SampleMNISTAPI对象。
  • 第一步主要就是构建了一个 SampleMNISTAPIParams对象

    • 随便看看有哪些变量以及初始值就可以了,没啥大不了的
SampleMNISTAPIParams initializeSampleParams(const samplesCommon::Args& args)
{
    SampleMNISTAPIParams params;
    if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths
    {
        params.dataDirs.push_back("data/mnist/");
        params.dataDirs.push_back("data/samples/mnist/");
    }
    else //!< Use the data directory provided by the user
    {
        params.dataDirs = args.dataDirs;
    }
    params.inputTensorNames.push_back("data");
    params.batchSize = 1;
    params.outputTensorNames.push_back("prob");
    params.dlaCore = args.useDLACore;
    params.int8 = args.runInInt8;
    params.fp16 = args.runInFp16;

    params.inputH = 28;
    params.inputW = 28;
    params.outputSize = 10;
    params.weightsFile = "mnistapi.wts";
    params.mnistMeansProto = "mnist_mean.binaryproto";

    return params;
}
  • 第二步更加简单,只是指定了两个成员变量的值,分别是mParams(params)/mEngine(nullptr)
  • 这里顺便介绍一下 SampleMNISTAPI的成员变量以及函数

1.2. 建立网络

  • 入口就是sample.build() 方法,主要流程如下:
//!
//! \brief Creates the network, configures the builder and creates the network engine
//!
//! \details This function creates the MNIST network by using the API to create a model and builds
//!          the engine that will be used to run MNIST (mEngine)
//!
//! \return Returns true if the engine was created successfully and false otherwise
//!
bool SampleMNISTAPI::build()
{
    // locateFile 函数用于权重文件位置,返回的是字符串路径
    // loadWeights 就是读取指定位置的权重文件,返回的是一个 std::map 对象
    // 我也不知道权重文件的具体格式,大概就是先读取一个整数(表示总函数数量),然后依次读取每个参数
    // 每个参数内容包括 name/type/size,分别是字符串/int/uint32
    // 这是一个成员变量
    mWeightMap = loadWeights(locateFile(mParams.weightsFile, mParams.dataDirs));

    // 创建 builder,准备构建网络
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }

    // 创建 Network 对象,后面网络就保存在这里
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if (!network)
    {
        return false;
    }

    // 创建构建网络所需的参数对象
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }

    // 构建网络
    auto constructed = constructNetwork(builder, network, config);
    
    // 验证网络是否构建成功
    if (!constructed)
    {
        return false;
    }

    // 网络只有一个输入,输入是一个三维对象,只有一个输出,输出也是三维对象
    assert(network->getNbInputs() == 1);
    auto inputDims = network->getInput(0)->getDimensions();
    assert(inputDims.nbDims == 3);
    assert(network->getNbOutputs() == 1);
    auto outputDims = network->getOutput(0)->getDimensions();
    assert(outputDims.nbDims == 3);

    return true;
}
  • 主要过程都在 constructNetwork中。
    • 内容平平无奇,就是一层一层构建网络,跟PyTorch啥的也没区别。
    • 注意,这里其实也包括了模型参数导入过程。
    • 官方文档里也有介绍。
//!
//! \brief Uses the API to create the MNIST Network
//!
//! \param network Pointer to the network that will be populated with the MNIST network
//!
//! \param builder Pointer to the engine builder
//!
bool SampleMNISTAPI::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config)
{
    // Create input tensor of shape { 1, 1, 28, 28 }
    ITensor* data = network->addInput(
        mParams.inputTensorNames[0].c_str(), DataType::kFLOAT, Dims3{1, mParams.inputH, mParams.inputW});
    assert(data);

    // Create scale layer with default power/shift and specified scale parameter.
    const float scaleParam = 0.0125f;
    const Weights power{DataType::kFLOAT, nullptr, 0};
    const Weights shift{DataType::kFLOAT, nullptr, 0};
    const Weights scale{DataType::kFLOAT, &scaleParam, 1};
    IScaleLayer* scale_1 = network->addScale(*data, ScaleMode::kUNIFORM, shift, scale, power);
    assert(scale_1);

    // Add convolution layer with 20 outputs and a 5x5 filter.
    IConvolutionLayer* conv1 = network->addConvolutionNd(
        *scale_1->getOutput(0), 20, Dims{2, {5, 5}, {}}, mWeightMap["conv1filter"], mWeightMap["conv1bias"]);
    assert(conv1);
    conv1->setStride(DimsHW{1, 1});

    // Add max pooling layer with stride of 2x2 and kernel size of 2x2.
    IPoolingLayer* pool1 = network->addPoolingNd(*conv1->getOutput(0), PoolingType::kMAX, Dims{2, {2, 2}, {}});
    assert(pool1);
    pool1->setStride(DimsHW{2, 2});

    // Add second convolution layer with 50 outputs and a 5x5 filter.
    IConvolutionLayer* conv2 = network->addConvolutionNd(
        *pool1->getOutput(0), 50, Dims{2, {5, 5}, {}}, mWeightMap["conv2filter"], mWeightMap["conv2bias"]);
    assert(conv2);
    conv2->setStride(DimsHW{1, 1});

    // Add second max pooling layer with stride of 2x2 and kernel size of 2x3>
    IPoolingLayer* pool2 = network->addPoolingNd(*conv2->getOutput(0), PoolingType::kMAX, Dims{2, {2, 2}, {}});
    assert(pool2);
    pool2->setStride(DimsHW{2, 2});

    // Add fully connected layer with 500 outputs.
    IFullyConnectedLayer* ip1
        = network->addFullyConnected(*pool2->getOutput(0), 500, mWeightMap["ip1filter"], mWeightMap["ip1bias"]);
    assert(ip1);

    // Add activation layer using the ReLU algorithm.
    IActivationLayer* relu1 = network->addActivation(*ip1->getOutput(0), ActivationType::kRELU);
    assert(relu1);

    // Add second fully connected layer with 20 outputs.
    IFullyConnectedLayer* ip2 = network->addFullyConnected(
        *relu1->getOutput(0), mParams.outputSize, mWeightMap["ip2filter"], mWeightMap["ip2bias"]);
    assert(ip2);

    // Add softmax layer to determine the probability.
    ISoftMaxLayer* prob = network->addSoftMax(*ip2->getOutput(0));
    assert(prob);
    prob->getOutput(0)->setName(mParams.outputTensorNames[0].c_str());
    network->markOutput(*prob->getOutput(0));

    // Build engine
    builder->setMaxBatchSize(mParams.batchSize);
    config->setMaxWorkspaceSize(16_MiB);
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        samplesCommon::setAllTensorScales(network.get(), 64.0f, 64.0f);
    }

    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }

    return true;
}

2. 其他相关代码

  • 对C++不熟悉,代码总是要慢慢看起来,慢慢写起来的。
  • 毕竟之后还要研究其他samples,不如就从这里入手,慢慢看。

2.1. 读取命令行参数

  • 主要引用了 samples/common/args.Parser.h 中相关的内容。
  • 样例中的代码很简单,在main函数中主要代码是:
// 新建保存命令行参数的struct对象
samplesCommon::Args args;

// 解析命令行参数内容,结果保存在args中
// 后续引用命令行参数就是用类似 `args.xxx` 的形式
bool argsOK = samplesCommon::parseArgs(args, argc, argv);

// 判断前一步解析是否成功(有没有多余参数啥的)
if (!argsOK)
{
    // 解析没有成功就返回帮助文档,告诉用于有什么错误
    sample::gLogError << "Invalid arguments" << std::endl;
    printHelpInfo();
    return EXIT_FAILURE;
}

// 如果输入
if (args.help)
{
    printHelpInfo();
    return EXIT_SUCCESS;
}
  • 命令行参数包括哪些?
    • 指的就是 samplesCommon::Args 对象,查看源码后发现就是一个struct,里面的每个参数就是一个命令行参数。
struct SampleParams
{
    int32_t batchSize{1};              //!< Number of inputs in a batch
    int32_t dlaCore{-1};               //!< Specify the DLA core to run network on.
    bool int8{false};                  //!< Allow runnning the network in Int8 mode.
    bool fp16{false};                  //!< Allow running the network in FP16 mode.
    std::vector<std::string> dataDirs; //!< Directory paths where sample data files are stored
    std::vector<std::string> inputTensorNames;
    std::vector<std::string> outputTensorNames;
};
  • 命令行参数解析过程关键代码
// 定义一组option对象列表,option是getopt中定义的结构体,具体定义如下
//struct option
//{
//  const char *name;
//  /* has_arg can't be an enum because some compilers complain about
//     type mismatches in all the code that assumes it is an int.  */
//  int has_arg;
//  int *flag;
//  int val;
//};
static struct option long_options[] = {{"help", no_argument, 0, 'h'}, {"datadir", required_argument, 0, 'd'}, {"int8", no_argument, 0, 'i'}, {"fp16", no_argument, 0, 'f'}, {"useILoop", no_argument, 0, 'l'}, {"saveEngine", required_argument, 0, 's'}, {"loadEngine", no_argument, 0, 'o'}, {"useDLACore", required_argument, 0, 'u'}, {"batch", required_argument, 0, 'b'}, {nullptr, 0, nullptr, 0}};

// 解析,得到结果 args
// 另外,如果包括了命令行参数,那参数包括在 optarg 中,这个参数定义在getopt.h中,可直接饮用
arg = getopt_long(argc, argv, "hd:iu", long_options, &option_index);

switch (arg)
{
    case 'h': ...;
    case 'd': do_something_with(optarg);
    ...
}

2.2. logger 相关

  • 日志相关代码都是TnesorRT samples中自带的,这里稍微分析一下。

  • 相关代码主要包括

// 记录一次测试的logger信息
// 请注意,好像实际输出日志文件的,并不是这个对象,而是 sample::gLogger
auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);

// 表示一次测试开始了,测试开始会有一串字符串显示,包括名称以及命令行参数
sample::gLogger.reportTestStart(sampleTest);

// 这个对象实现了tensorrt推理功能
SampleMNISTAPI sample(initializeSampleParams(args));

// 普通的展示日志信息
// 从这里也可以看到,输出日志信息用的是 sample::gLogger,跟sampleTest没啥关系
sample::gLogInfo << "Building and running a GPU inference engine for MNIST API" << std::endl;

// 后面所有 sample 相关操作结果都通过sampleTest
// sampleTest的作用在于,包含了一个 `reportTestEnd` 函数,定义了输出日志的格式
if (!sample.build())
{
    return sample::gLogger.reportFail(sampleTest);
}
if (!sample.infer())
{
    return sample::gLogger.reportFail(sampleTest);
}
if (!sample.teardown())
{
    return sample::gLogger.reportFail(sampleTest);
}
return sample::gLogger.reportPass(sampleTest);
  • 要理解这些内容,首先一个问题是,sampleTest倒是是个啥
    • 通过工厂方法,创建了TestAtom类的对象
    • 这个对象主要包括三个属性,mStarted/mName/mCmdline,后两个都是字符串,分别表示名称以及所有命令行信息(命令行参数之间用空格分割)。

你可能感兴趣的:(C++,TensorRT,TensorRT,模型部署)