本文首发于https://yangzhenyu.tech/
作者:Frank
导读:本文主要带来对TensorRT中自带的sample:sampleOnnxMNIST的源码解读,官方例程是非常好的学习资料,通过吃透一个官方例程,就可以更加深刻地了解TensorRT的每一步流程,明白其中套路,再去修改代码推理我们自己的网络就是很容易的事情了。
TensorRT的一份源码要怎么看?TensorRT是一个很具有操作流程性质的代码,我们从main函数开始着手,顺着执行的流程步骤向前推,遇到各种功能的嵌套,就像剥洋葱一样一层一层剥开,了解其具体细节后再回到主线。我们以BFS(广度优先遍历)的方式看代码,不仅不会忘记主线任务,还能明白各种细节,从而真正吃透代码。
前排提示:阅读本文代码解析时经常需要深入代码,然后再跳出来,所以读者阅读时请多关注内容在目录中的层级关系,方便梳理。最好可以用VS Code自己打开一份SampleOnnxMNIST的源码,跟着笔者一起看,效果会很好。
main函数遵循了用TensorRT推理的基本步骤:
首先我们来瞅一眼main函数的大框架,大致了解各个部分代码是在干什么的:
int main(int argc, char** argv)
{
samplesCommon::Args args; // 接收用户传递参数的变量
bool argsOK = samplesCommon::parseArgs(args, argc, argv); // 将main函数的参数argc和argv解释成args,返回转换是否成功的bool值
if (!argsOK) // 如果转换不成功,则用日志类报错并打印帮助信息,退出程序。
{
sample::gLogError << "Invalid arguments" << std::endl;
printHelpInfo();
return EXIT_FAILURE;
}
if (args.help) // 如果接收的参数是请求打印帮助信息,则打印帮助信息,退出程序。
{
printHelpInfo();
return EXIT_SUCCESS;
}
auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv); // 定义一个日志类
sample::gLogger.reportTestStart(sampleTest); // 记录日志的开始
SampleOnnxMNIST sample(initializeSampleParams(args)); // 定义一个sample实例
sample::gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;
if (!sample.build()) // 【主要】在build方法中构建网络,返回构建网络是否成功的状态
{
return sample::gLogger.reportFail(sampleTest);
}
if (!sample.infer()) // 【主要】读取图像并进行推理,返回推理是否成功的状态
{
return sample::gLogger.reportFail(sampleTest);
}
return sample::gLogger.reportPass(sampleTest); // 报告结束
}
好了,大模样有了,我们把它分成几块代码,现在开始吃它!
samplesCommon::Args args; // 接收用户传递参数的变量
bool argsOK = samplesCommon::parseArgs(args, argc, argv); // 将main函数的参数argc和argv解释成args,返回转换是否成功的bool值
if (!argsOK) // 如果转换不成功,则用日志类报错并打印帮助信息,退出程序。
{
sample::gLogError << "Invalid arguments" << std::endl;
printHelpInfo();
return EXIT_FAILURE;
}
if (args.help) // 如果接收的参数是请求打印帮助信息,则打印帮助信息,退出程序。
{
printHelpInfo();
return EXIT_SUCCESS;
}
首先实例化了一个samplesCommon
类型的变量args
,该类的定义在TensorRT中Samples/common/argsParser.h
文件中,这个common文件夹下定义了大量TensorRT中常用的操作和类,Args就是其中之一,是一个用管理用户传递进程序的参数的类。
Args类定义在argsParser.h文件的samplesCommon命名空间中,具体定义如下:
//!
//! /brief Struct to maintain command-line arguments.
//!
struct Args
{
bool runInInt8{false}; // 用INT8精度运行
bool runInFp16{false}; // 用FP16精度运行
bool help{false}; // 打印help信息
int32_t useDLACore{-1}; // 使用DLA核
int32_t batch{1}; // batch的大小
std::vector dataDirs; // 数据文件夹的位置
std::string saveEngine; // 存储引擎
std::string loadEngine; // 加载引擎
bool useILoop{false}; // TODO 未知
};
samplesCommon::parseArgs()
函数接收三个参数,分别为args和argc、argv,返回一个执行状态标识。这很容易理解,argc、argv是main函数的参数,该函数把argc和argv解读成Args类,便于后续操作。该函数也定义在argsParser.h文件的samplesCommon命名空间下,具体定义太冗长就不看了,只要理解它是将程序传入的用户参数解释成Args类型就行了。
然后分别是判断参数传递是否正确,以及如果传入请求打印信息,则返回打印信息的条件判断。
auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
sample::gLogger.reportTestStart(sampleTest);
auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
这一句是定义了一个日志类,名称为sampleTest,我们看看它在代码里的几处用到的地方:
auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
...
sample::gLogger.reportTestStart(sampleTest);
...
return sample::gLogger.reportFail(sampleTest);
...
return sample::gLogger.reportPass(sampleTest);
看起来都是一些不痛不痒的东西,结合它的名称gLogger,基本可以断定它就是一个记录日志的类了。TensorRT中把它的具体实现放在了common文件夹中,感兴趣的同学可以自己看一下,不是很难。不看的话只需要知道它的功能是记录日志就行。其中gSampleName是本文件定义的一个全局静态变量,const std::string gSampleName = "TensorRT.sample_onnx_mnist";
是用来指示这个日志文件的记录内容的。
后面在创建完sample后还跑一行sample::gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;
,很简单,通过对sample::gLogInfo输入流,就能在屏幕上打印出log如下:
[12/21/2020-16:19:21] [I] Building and running a GPU inference engine for Onnx MNIST
sample可以理解成是当前推理MNIST程序的一个最大的封装了,它把整个程序封装成了一个sample样例,这样做也说得过去。
SampleOnnxMNIST sample(initializeSampleParams(args));
这里自定义了一个类SampleOnnxMNIST
,并且调用了一个初始化sample参数的函数,该函数接收args作为参数。
我们先不着急看类的定义,先来看看这个类的构造函数接收的参数initializeSampleParams(args)
吧,它比类的定义简单,而且也很重要。来看该函数的定义:
//!
//! \brief Initializes members of the params struct using the command line args
//!
samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
samplesCommon::OnnxSampleParams params;
if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths
{
params.dataDirs.push_back("data/mnist/");
params.dataDirs.push_back("data/samples/mnist/");
}
else //!< Use the data directory provided by the user
{
params.dataDirs = args.dataDirs;
}
params.onnxFileName = "mnist.onnx";
params.inputTensorNames.push_back("Input3");
params.outputTensorNames.push_back("Plus214_Output_0");
params.dlaCore = args.useDLACore;
params.int8 = args.runInInt8;
params.fp16 = args.runInFp16;
return params;
}
它创建了一个samplesCommon::OnxxSampleParams类型的变量,然后根据args设定各种参数。samplesCommon::OnxxSampleParams类型的定义在common.h中,如下:
struct OnnxSampleParams : public SampleParams
{
std::string onnxFileName; //!< Filename of ONNX file of a network
};
这是继承自SampleParams结构体的,只不过新增了一个onxxFileName成员变量,我们看看继承的结构体SampleParams
struct SampleParams
{
int32_t batchSize{1}; //!< Number of inputs in a batch
int32_t dlaCore{-1}; //!< Specify the DLA core to run network on.
bool int8{false}; //!< Allow runnning the network in Int8 mode.
bool fp16{false}; //!< Allow running the network in FP16 mode.
std::vector dataDirs; //!< Directory paths where sample data files are stored
std::vector inputTensorNames;
std::vector outputTensorNames;
};
也只不过是封装了几个成员变量而已,没啥大不了的。
我们再来看看巨庞大的SampleOnnxMNIST
类,这个就是我们程序的核心类了,封装了大量重要的功能。
//! \brief The SampleOnnxMNIST class implements the ONNX MNIST sample
//!
//! \details It creates the network using an ONNX model
//!
class SampleOnnxMNIST
{
template
using SampleUniquePtr = std::unique_ptr;
// using关键字是c++11中为类取别名的新关键字
// std::unique_ptr是智能指针的关键字
public:
SampleOnnxMNIST(const samplesCommon::OnnxSampleParams& params)
: mParams(params)
, mEngine(nullptr)
{
}
//!
//! \brief Function builds the network engine
//!
bool build();
//!
//! \brief Runs the TensorRT inference engine for this sample
//!
bool infer();
private:
samplesCommon::OnnxSampleParams mParams; //!< The parameters for the sample.
nvinfer1::Dims mInputDims; //!< The dimensions of the input to the network.
nvinfer1::Dims mOutputDims; //!< The dimensions of the output to the network.
int mNumber{0}; //!< The number to classify
std::shared_ptr mEngine; //!< The TensorRT engine used to run the network
//!
//! \brief Parses an ONNX model for MNIST and creates a TensorRT network
//!
bool constructNetwork(SampleUniquePtr& builder,
SampleUniquePtr& network, SampleUniquePtr& config,
SampleUniquePtr& parser);
//!
//! \brief Reads the input and stores the result in a managed buffer
//!
bool processInput(const samplesCommon::BufferManager& buffers);
bool prepareInput(const samplesCommon::BufferManager& buffers);
//!
//! \brief Classifies digits and verify result
//!
bool verifyOutput(const samplesCommon::BufferManager& buffers);
};
第一句给智能指针用using关键字起了个新的名字,方便后面使用
template
using SampleUniquePtr = std::unique_ptr;
这里将创建一个指向模版类T类型、销毁方式为samplesCommon::InferDeleter
的智能指针unique_ptr的声明,用using
关键字重命名为SampleUniquePtr
SampleOnnxMNIST(const samplesCommon::OnnxSampleParams& params)
: mParams(params),
mEngine(nullptr)
{}
这是SampleOnnxMNIST
类的构造函数,接收一个samplesCommon::OnnxSampleParams
类作为参数,并用它来初始化成员mParams
,默认成员mEngine
为nullptr
然后是两个最核心的成员函数build()
和infer()
,分别用于构建网络以及进行推理。后面我们再详细讲这两个最关键的函数。
//!
//! \brief Function builds the network engine
//!
bool build();
//!
//! \brief Runs the TensorRT inference engine for this sample
//!
bool infer();
再来看私有成员:
samplesCommon::OnnxSampleParams mParams; //!< The parameters for the sample.
说的很清楚了,是该sample的参数。
nvinfer1::Dims mInputDims; //!< The dimensions of the input to the network.
nvinfer1::Dims mOutputDims; //!< The dimensions of the output to the network.
int mNumber{0}; //!< The number to classify 存储读取的手写数字图像的具体数字的gt
mInputDims
和mOutputDims
指的是输入和输出Tensor的维度信息,它们的类型是nvinfer1::Dims类型,Dims类型的定义如下,在./include/NvInferRuntimeCommom.h
文件下
class Dims
{
public:
static const int32_t MAX_DIMS = 8; //!< The maximum number of dimensions supported for a tensor.
int32_t nbDims; //!< The number of dimensions.
int32_t d[MAX_DIMS]; //!< The extent of each dimension.
TRT_DEPRECATED DimensionType type[MAX_DIMS]; //!< The type of each dimension, provided for backwards compatibility
//!< and will be removed in TensorRT 8.0.
};
std::shared_ptr
定义的是一个用来run网络的engine,是一个指向nvinfer1::IcudaEngine
类型的智能指针,它是具体的网络结构以及参数设定的更上层的封装。
//!
//! \brief Parses an ONNX model for MNIST and creates a TensorRT network
//!
bool constructNetwork(SampleUniquePtr& builder,
SampleUniquePtr& network, SampleUniquePtr& config,
SampleUniquePtr& parser);
//!
//! \brief Reads the input and stores the result in a managed buffer
//!
bool processInput(const samplesCommon::BufferManager& buffers);
//!
//! \brief Classifies digits and verify result
//!
bool verifyOutput(const samplesCommon::BufferManager& buffers);
constructNetwork
解释一个ONNX模型成TensorRT的网络模型
processInput
里实现输入的读取和处理
verifyOutput
对推理结果的输出进行验证
好了,分析完巨长的SampleOnnxMNIST
类,我们再回到主线任务main中
SampleOnnxMNIST sample(initializeSampleParams(args));
。在该类的构造函数中使用了成员列表初始化的方式,接收一个const samplesCommon::OnnxSampleParams& params
类型的成员变量。实际代码中,我们看见了它是一个接收args作为参数的函数initializeSampleParams
的返回值,这个函数本质就是把args再解析了,然后创建一个samplesCommon::OnnxSampleParams类型的变量,作为sample的参数,把解析好的args的值传给这个参数,然后再传给sample用来初始化sample(args->params->sample:感觉有点像在套娃,我想说:禁止套娃)
接下来在main函数中执行了sample.build()
,什么叫build一个sample呢,可以理解成搭建了一个能够执行推理的引擎,而这个引擎的构建过程涉及网络的读取和构建、一系列配置参数的设定等,它是TensorRT中最为重要的准备步骤,build好之后就可以进行infer了。来看函数源码:
//!
//! \brief Creates the network, configures the builder and creates the network engine
//!
//! \details This function creates the Onnx MNIST network by parsing the Onnx model and builds
//! the engine that will be used to run MNIST (mEngine)
//!
//! \return Returns true if the engine was created successfully and false otherwise
//!
bool SampleOnnxMNIST::build()
{
auto builder = SampleUniquePtr (nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
if (!builder)
{
return false;
}
const auto explicitBatch = 1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto network = SampleUniquePtr(builder->createNetworkV2(explicitBatch));
if (!network)
{
return false;
}
auto config = SampleUniquePtr(builder->createBuilderConfig());
if (!config)
{
return false;
}
auto parser
= SampleUniquePtr(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
if (!parser)
{
return false;
}
auto constructed = constructNetwork(builder, network, config, parser);
if (!constructed)
{
return false;
}
mEngine = std::shared_ptr(
builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
if (!mEngine)
{
return false;
}
assert(network->getNbInputs() == 1);
mInputDims = network->getInput(0)->getDimensions();
assert(mInputDims.nbDims == 4);
assert(network->getNbOutputs() == 1);
mOutputDims = network->getOutput(0)->getDimensions();
assert(mOutputDims.nbDims == 2);
return true;
}
有点长,我们一段一段来看:
auto builder = SampleUniquePtr(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
if (!builder) // 创建builder是否成功判断
{
return false;
}
这一句就是用一个在SampleOnnxMNIST类中用using重新命名、定义了自动回收方式的unique_ptr指针 初始化了一个指向nvinfer1::IBuilder
类型数据的指针,名称为builder,该指针指向的内容为nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger())
的返回值。
我们再来一层一层剥洋葱,首先看看这个nvinfer1::Ibuilder是什么,它定义在了./include/Nvinfer.h
文件中,文档中的描述是:从一个网络的定义初始化一个engine,因为类的定义太长了,就不看具体内容了。
nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger())
定义于./include/Nvinfer.h
文件中,函数的描述是创建一个IBuilder类的实例
再看下面:
const auto explicitBatch = 1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
这里的语法看着有点困难,按照从右到左的顺序解读一下:
NetworkDefinitionCreationFlag::kEXPLICIT_BATCH=0
是一个枚举类型值,头文件中的解释是“Mark the network to be an explicit batch network”,即“标记这个网络是一个显式的批处理网络”,它的作用是如果网络的输入维度在运行时是变化的,那么需要把网络设定成这种"explicit batch network"。还有另外一个网络类型的定义为NetworkDefinitionCreationFlag::kEXPLICIT_PRECISION=1
,这个似乎是用于权重已经被量化到[-127,127]的网络,所以当网络设定为这种类型时,builder不会量化网络中的任何权重。一般而言用不到这种类型的吧,所以具体的定义参见./include/NvInfer.h
中对这个枚举类型的定义就行。
然后是static_cast
,它是把设定的网络flag的枚举值强制转换为uint32_t类型,接下来是1U << ...
,因为网络构建的flag只有0和1两种,所以这句是把1左移0位或左移1为,对应二进制01
或10
,最后用这个值初始化了explicitBatch常量。
auto network = SampleUniquePtr(builder->createNetworkV2(explicitBatch));
if (!network)
{
return false;
}
用刚才设定的explicitBatch的网络flag创建了一个网络定义
auto config = SampleUniquePtr(builder->createBuilderConfig());
if (!config)
{
return false;
}
创建了一个config,tensorrt的docs对IBuilderConfig的解释是:The IBuilderConfig has many properties that you can set in order to control such things as the precision at which the network should run, and autotuning parameters such as how many times TensorRT should time each kernel when ascertaining which is fastest (more iterations lead to longer runtimes, but less susceptibility to noise.) You can also query the builder to find out what reduced precision types are natively supported by the hardware.
auto parser
= SampleUniquePtr(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
if (!parser)
{
return false;
}
创建了一个解析器parser
auto constructed = constructNetwork(builder, network, config, parser);
if (!constructed)
{
return false;
}
用前面创建的builder、network、config、parser共同构建网络,又出现一头洋葱函数,现在开始剥洋葱~
//!
//! \brief Uses a ONNX parser to create the Onnx MNIST Network and marks the
//! output layers
//!
//! \param network Pointer to the network that will be populated with the Onnx MNIST network
//!
//! \param builder Pointer to the engine builder
//!
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr& builder,
SampleUniquePtr& network, SampleUniquePtr& config,
SampleUniquePtr& parser)
{
auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
static_cast(sample::gLogger.getReportableSeverity()));
if (!parsed)
{
return false;
}
config->setMaxWorkspaceSize(16_MiB);
if (mParams.fp16)
{
config->setFlag(BuilderFlag::kFP16);
}
if (mParams.int8)
{
config->setFlag(BuilderFlag::kINT8);
samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
}
samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
return true;
}
这个也是定义在SampleOnnxMNIST类里的函数,看函数的描述为“用一个ONNX parser创建一个Onnx MNIST网络,并且标记其输出层”
auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
static_cast(sample::gLogger.getReportableSeverity()));
if (!parsed)
{
return false;
}
套娃又来了,来,干她
/** \brief Parse an onnx model file, can be a binary protobuf or a text onnx model
* calls parse method inside.
*
* \param File name
* \param Verbosity Level
*
* \return true if the model was parsed successfully
*
*/
virtual bool parseFromFile(const char* onnxModelFile, int verbosity) = 0;
这回跑到TensorRT的头文件里了./include/NvOnnxParser.h
,没劲(因为TensorRT不开源),那我们就远观一下吧
描述是用来解析onnx文件的,这就有意思了,这可是程序里很重要的内容啊。我们的代码里传给它的两个参数分别是locateFile(mParams.onnxFileName, mParams.dataDirs).c_str()
和static_cast
先看locateFile(mParams.onnxFileName, mParams.dataDirs).c_str()
,光从名字上就可以看出它的作用是寻找onnx文件位置的,它的主要作用就是把参数一onnx文件名接到参数二的目录后面, 组成onnx文件的完整的路径并返回这个路径,代码实现上考虑的情况众多,就不展开细讲了,感兴趣的话请看common.h
文件中locateFile函数的具体代码(这种小功能倒是开源了)。mParams是SampleOnnxMNIST中的成员变量,在该类的构造函数中用成员列表初始化的方式进行初始化的,在前面内容中已经讲过了。
回到build函数里,继续看下面:
config->setMaxWorkspaceSize(16_MiB);
if (mParams.fp16)
{
config->setFlag(BuilderFlag::kFP16);
}
if (mParams.int8)
{
config->setFlag(BuilderFlag::kINT8);
samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
}
samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
config->setMaxWorkspaceSize(16_MiB);
设置了最大工作空间的大小,TensorRT的docs中对最大工作空间的含义解释是这样的:Layer algorithms often require temporary workspace. This parameter limits the maximum size that any layer in the network can use. If an insufficient scratch is provided, it is possible that TensorRT may not be able to find an implementation for a given layer.
后面在判断是以FP16运行还是INT8运行的时候调用了config->setFlag,如果是以INT8运行,还得调用一句samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
这个函数的代码也挺有意思的,虽然不算是重点,但我们也来看一眼:
inline void setAllTensorScales(INetworkDefinition* network, float inScales = 2.0f, float outScales = 4.0f)
{
// Ensure that all layer inputs have a scale.
for (int i = 0; i < network->getNbLayers(); i++)
{
auto layer = network->getLayer(i);
for (int j = 0; j < layer->getNbInputs(); j++)
{
ITensor* input{layer->getInput(j)};
// Optional inputs are nullptr here and are from RNN layers.
if (input != nullptr && !input->dynamicRangeIsSet())
{
ASSERT(input->setDynamicRange(-inScales, inScales));
}
}
}
// Ensure that all layer outputs have a scale.
// Tensors that are also inputs to layers are ingored here
// since the previous loop nest assigned scales to them.
for (int i = 0; i < network->getNbLayers(); i++)
{
auto layer = network->getLayer(i);
for (int j = 0; j < layer->getNbOutputs(); j++)
{
ITensor* output{layer->getOutput(j)};
// Optional outputs are nullptr here and are from RNN layers.
if (output != nullptr && !output->dynamicRangeIsSet())
{
// Pooling must have the same input and output scales.
if (layer->getType() == LayerType::kPOOLING)
{
ASSERT(output->setDynamicRange(-inScales, inScales));
}
else
{
ASSERT(output->setDynamicRange(-outScales, outScales));
}
}
}
}
}
它的主要作用大概是给每个层设定一个scale,可以看到它的主要逻辑是先用条件判断每个层是什么,然后再分别把判定为是input的层和output的层用setDynamicRange函数把DynamicRange设定为[-inscale,inscale]和[-outscale,outscale]。这里的Tensor Scale在docs讲得不是很详细,但感觉可以大致理解成是Tensor值的可变范围吧,比如INT8量化对应的dynamic range就是[-127,127],而且文档中有指出如果要使用INT8量化,则必须显式地设定Tensor Scale。除此之外我们还可以关注一下这几个用法:network->getNbLayers()
可以获得网络中的层数、network->getLayer(i)
可以获得序号i的层的指针,还可以通过ITensor* output{layer->getOutput(j)};
再判断output是否为空指针来确定当前这个layer层是否是输出层,同理改成getInput也可以判断是否为输入层。(注:这里只是个人理解,未经验证,不一定正确)
再回到上一层代码,最后再调用了samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
设定了开启DLA加速。
至此,我们终于分析完了SampleOnnxMNIST::constructNetwork
这个函数,可以回到上一层build的代码了,我们继续
mEngine = std::shared_ptr(
builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
if (!mEngine)
{
return false;
}
这里主要就是创建了一个engine,照我的理解就是network、config的集成,是一个集成好了的、可以运行的东西。
assert(network->getNbInputs() == 1);
mInputDims = network->getInput(0)->getDimensions();
assert(mInputDims.nbDims == 4);
assert(network->getNbOutputs() == 1);
mOutputDims = network->getOutput(0)->getDimensions();
assert(mOutputDims.nbDims == 2);
这几句是用来判断网络的输入和输出的dims是否正确的,因为我们这里的输入是一批图像,维度大概定义为如[B,C,H,W]这种形式,所以要判断构建的网络的输入层维度是不是4,而且是不是只有一个输入入口。输出是10个数字的概率,维度定义为[1,N]这种形式,所以要判断网络的输出层维度是不是2,而且是不是只有一个输出口。
好了好了,终于看完build了,小结一下build作用就是各种加载和设定参数细节,各种解析网络文件,然后搭建网络,最后按照一定的configuration构建了一个engine,engine就是引擎,它是推动TensorRT推理的真正实体。
回到main函数,build完成,东西被造出来了,接下来就要infer了
if (!sample.infer())
{
return sample::gLogger.reportFail(sampleTest);
}
infer里面有什么呢?
//!
//! \brief Runs the TensorRT inference engine for this sample
//!
//! \details This function is the main execution function of the sample. It allocates the buffer,
//! sets inputs and executes the engine.
//!
bool SampleOnnxMNIST::infer()
{
// Create RAII buffer manager object
samplesCommon::BufferManager buffers(mEngine);
auto context = SampleUniquePtr(mEngine->createExecutionContext());
if (!context)
{
return false;
}
// Read the input data into the managed buffers
assert(mParams.inputTensorNames.size() == 1);
if (!processInput(buffers))
{
return false;
}
// Memcpy from host input buffers to device input buffers
buffers.copyInputToDevice();
bool status = context->executeV2(buffers.getDeviceBindings().data());
if (!status)
{
return false;
}
// Memcpy from device output buffers to host output buffers
buffers.copyOutputToHost();
// Verify results
if (!verifyOutput(buffers))
{
return false;
}
return true;
}
我们先大致看一下整体的流程:先创建了一个buffer,然后构建上下文环境,接着把输入图像读取到buffer里,再把buffer的输入从主机端拷贝到设备端,调用上下文执行函数,再从设备端把输出拷贝到主机端,最后再验证结果。
我们再来逐块看:
// Create RAII buffer manager object
samplesCommon::BufferManager buffers(mEngine);
这个就是创建了一个buffer,查了官方注释里RAII的意思是:“资源获取就是初始化”,是C++语言的一种管理资源、避免泄漏的惯用法。C++标准保证任何情况下,已构造的对象最终会销毁,即它的析构函数最终会被调用。简单的说,RAII 的做法是使用一个对象,在其构造时获取资源,在对象生命期控制对资源的访问使之始终保持有效,最后在对象析构的时候释放资源。
auto context = SampleUniquePtr(mEngine->createExecutionContext());
if (!context)
{
return false;
}
这个也很重要,创建了上下文信息,上下文我在CUDA中有遇到过,可以将其理解成一个管理多个对象生命周期的容器
// Read the input data into the managed buffers
assert(mParams.inputTensorNames.size() == 1);
if (!processInput(buffers))
{
return false;
}
这里就是处理输入了,主要是在processInput函数中把图像读取进来,看看processInput函数
//!
//! \brief Reads the input and stores the result in a managed buffer
//!
bool SampleOnnxMNIST::processInput(const samplesCommon::BufferManager& buffers)
{
const int inputH = mInputDims.d[2]; // 获取网络输入层中定义的图像的高和宽
const int inputW = mInputDims.d[3];
// Read a random digit file
srand(unsigned(time(nullptr))); // 设定随机数,用来随机读取一张图像
std::vector fileData(inputH * inputW); // 创建一个vector存储读入的图像
mNumber = rand() % 10; // 获得从0~9范围内的随机数,选择一张这样的图像作为输入,并传给SampleOnnxMNIST类的mNumber成员变量,作为gt存储着,后面会用来判断预测值和gt是否相同。
readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW); // 用官方common实现的readPGMFile函数读取一张图像到fileData中
// Print an ascii representation 用ASCII码把数字的图片打印到终端
sample::gLogInfo << "Input:" << std::endl;
for (int i = 0; i < inputH * inputW; i++)
{
sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
}
sample::gLogInfo << std::endl;
// 把数字填充到buffer中input的相应位置
float* hostDataBuffer = static_cast(buffers.getHostBuffer(mParams.inputTensorNames[0]));
for (int i = 0; i < inputH * inputW; i++)
{
hostDataBuffer[i] = 1.0 - float(fileData[i] / 255.0);
}
return true;
}
这段代码我们只取其中一部分来具体讲解,其他部分都已加上注释了
readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);
这段代码的作用是把随机选择的一个数字的pgm图像加载出来,具体函数定义如下:
inline void readPGMFile(const std::string& fileName, uint8_t* buffer, int inH, int inW)
{
std::ifstream infile(fileName, std::ifstream::binary);
assert(infile.is_open() && "Attempting to read from a file that is not open.");
std::string magic, h, w, max;
infile >> magic >> h >> w >> max;
infile.seekg(1, infile.cur);
infile.read(reinterpret_cast(buffer), inH * inW);
}
前面提到fileData是定义的一个vector,而.data()的功能就是获取这个vector的首个元素的地址。看起来也就是根据PGM文件的格式,把图像加载到fileData.data()中而已,不难,但这种处理输入的方式值得借鉴,对于其他类型的图像采用类似的方法也是适用的。
再看下面:
float* hostDataBuffer = static_cast(buffers.getHostBuffer(mParams.inputTensorNames[0]));
for (int i = 0; i < inputH * inputW; i++)
{
hostDataBuffer[i] = 1.0 - float(fileData[i] / 255.0); // 原始图像是8位黑白图像,且是白底黑字的,将它转换到0~1且是黑底白字。
}
首先用float* hostDataBuffer = static_cast
获取指向buffers中inputTensorName的内存区域的指针。因为我们在网络中定义的输入大小和读取的图像大小是一致的,所以图像中的数据可以正好填满这个buffers中事先开辟好留给输入的内存空间。
我们再回到infer函数后面部分
// Memcpy from host input buffers to device input buffers
buffers.copyInputToDevice();
bool status = context->executeV2(buffers.getDeviceBindings().data());
if (!status)
{
return false;
}
// Memcpy from device output buffers to host output buffers
buffers.copyOutputToHost();
这个就是异构编程都会涉及的一步了,把数据从主机端拷贝到设备端,在设备端执行运算,然后把结果再从设备端拷贝到主机端。
再看下面
// Verify results
if (!verifyOutput(buffers))
{
return false;
}
接下来就是验证结果是否正确了,这部分不是TensorRT推理程序的必须动作,但在这个官方例程中是加上这部分,验证推理结构是否正确的,简单看一下这个verifyOutput函数
//!
//! \brief Classifies digits and verify result
//!
//! \return whether the classification output matches expectations
//!
bool SampleOnnxMNIST::verifyOutput(const samplesCommon::BufferManager& buffers)
{
const int outputSize = mOutputDims.d[1]; /// 获得网络的输出层总共有多少个输出(即多少类)
float* output = static_cast(buffers.getHostBuffer(mParams.outputTensorNames[0])); // 获取存储在buffers中的输出结果
float val{0.0f};
int idx{0};
// Calculate Softmax 把输出用softmax转换成置信度,并打印出来
float sum{0.0f};
for (int i = 0; i < outputSize; i++)
{
output[i] = exp(output[i]);
sum += output[i];
}
sample::gLogInfo << "Output:" << std::endl;
for (int i = 0; i < outputSize; i++)
{
output[i] /= sum;
val = std::max(val, output[i]);
if (val == output[i])
{
idx = i;
}
sample::gLogInfo << " Prob " << i << " " << std::fixed << std::setw(5) << std::setprecision(4) << output[i]
<< " "
<< "Class " << i << ": " << std::string(int(std::floor(output[i] * 10 + 0.5f)), '*')
<< std::endl;
}
sample::gLogInfo << std::endl;
return idx == mNumber && val > 0.9f; // 如果预测结果和实际相同,并且置信度大于0.9,则返回true
}
太棒了,infer()函数也被扒光看个透彻了。
我们看看main函数里还有什么:
return sample::gLogger.reportPass(sampleTest);
我去,就这?写了一行成功的日志,然后退出了!
至此,我们对TensorRT官方例程sampleOnnxMNIST终于解读完毕,相信收获也是非常多的。看起来冗长复杂的TensorRT代码也不过这样,世上无难事只要肯钻研。现在我们了解了它,下一章我们要驯服它,魔改sampleOnnxMNIST,让它不是仅仅能识别几个无聊的数字,还能识别ImageNet数据集中一千类的东西!