sampleMNIST 是一个简单的“hello world”示例:它完成基本的初始化设置,并使用 Caffe 解析器(Caffe parser)来初始化 TensorRT,演示了如何用 C++ 通过 TensorRT 加速推理。
该示例使用已经在 MNIST 数据集上训练好的 Caffe 模型,网络中用到了以下几种层:
Activation layer 激活层
Convolution layer 卷积层
FullyConnected layer 全连接层
Pooling layer 池化层
Scale layer Scale 层
SoftMax layer SoftMax 层
该示例需要读取 Caffe 的三个文件(mnist.prototxt、mnist.caffemodel 和 mnist_mean.binaryproto)来构建网络、进行加速并执行推理。
//! \brief Entry point of the sample: parses the command line, then builds the engine,
//! runs inference and tears the sample down, reporting pass/fail through gLogger.
//! NOTE(review): brace-only lines and the bodies of the help / invalid-argument branches
//! are not present in this excerpt; presumably they were dropped when the code was
//! pasted - confirm against the full sample source.
int main(int argc, char** argv)
samplesCommon::Args args;
bool argsOK = samplesCommon::parseArgs(args, argc, argv);
// Help was requested on the command line (branch body not shown in this excerpt).
if (args.help)
// Argument parsing failed: log the problem.
if (!argsOK)
gLogError << "Invalid arguments" << std::endl;
// Handle for this run; it is passed to reportFail()/reportPass() below.
auto sampleTest = gLogger.defineTest(gSampleName, argc, const_cast<const char**>(argv));
MNISTSampleParams params = initializeSampleParams(args);// initialize the sample parameters
SampleMNIST sample(params);// create the sample object
gLogInfo << "Building and running a GPU inference engine for MNIST" << std::endl;
if (!sample.build())// build the engine
return gLogger.reportFail(sampleTest);
if (!sample.infer())// run inference
return gLogger.reportFail(sampleTest);
if (!sample.teardown())// release resources
return gLogger.reportFail(sampleTest);
// All three stages succeeded.
return gLogger.reportPass(sampleTest);
//! \brief Wraps the MNIST sample: declares the build, infer and teardown steps and holds
//! the engine, parameters, input dimensions and mean blob that those steps share.
//! NOTE(review): braces and access specifiers are not present in this excerpt.
class SampleMNIST
//! Owning pointer alias whose deleter is samplesCommon::InferDeleter.
template <typename T>
using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;
//! \brief Constructs the sample, keeping a copy of the supplied parameters in mParams
SampleMNIST(const MNISTSampleParams& params)
: mParams(params)
//! \brief Function builds the network engine
bool build();
//! \brief This function runs the TensorRT inference engine for this sample
bool infer();
//! \brief This function can be used to clean up any state created in the sample class
bool teardown();
//! \brief This function uses a Caffe parser to create the MNIST Network and marks the
//! output layers
void constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder, SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser);
//! \brief Reads the input and mean data, preprocesses, and stores the result in a managed buffer
bool processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const;
//! \brief Verifies that the output is correct and prints it
bool verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const;
std::shared_ptr<nvinfer1::ICudaEngine> mEngine = nullptr; //!< The TensorRT engine used to run the network
MNISTSampleParams mParams; //!< The parameters for the sample.
nvinfer1::Dims mInputDims; //!< The dimensions of the input to the network.
SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob> mMeanBlob; //! the mean blob, which we need to keep around until build is done
//! \brief This function creates the network, configures the builder and creates the network engine
//! \details This function creates the MNIST network by parsing the caffe model and builds
//! the engine that will be used to run MNIST (mEngine)
//! \return Returns true if the engine was created successfully and false otherwise
//! NOTE(review): brace-only lines are not present in this excerpt; other builder
//! configuration statements may have been dropped as well - confirm against the full source.
bool SampleMNIST::build()
// Create the builder, the empty network definition and the Caffe parser; give up as
// soon as any of the three factory calls returns null.
auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(gLogger.getTRTLogger()));
if (!builder)
return false;
auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
if (!network)
return false;
auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
if (!parser)
return false;
constructNetwork(builder, network, parser);// populate the network from the Caffe model files
samplesCommon::enableDLA(builder.get(), mParams.dlaCore);
// Build the engine; it is stored in a shared_ptr that uses InferDeleter as its deleter.
mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(builder->buildCudaEngine(*network), samplesCommon::InferDeleter());
if (!mEngine)
return false;
// The sample expects exactly one network input with three dimensions.
assert(network->getNbInputs() == 1);
mInputDims = network->getInput(0)->getDimensions();// dimensions of the network input
assert(mInputDims.nbDims == 3);
return true;
//! \brief Reads the input and mean data, preprocesses, and stores the result in a managed buffer
//! \details Reads the input image, prints it as ASCII art, and stores its pixels in the host buffer.
//! \param buffers Buffer manager that provides the host buffer for the input tensor
//! \param inputTensorName Name of the input tensor whose host buffer is filled
//! \param inputFileIdx Digit whose "<digit>.pgm" file is looked up in mParams.dataDirs
//! \return Always true in the code shown here
//! NOTE(review): only the image file is read in this block; the mean blob is parsed in
//! constructNetwork. Brace-only lines are not present in this excerpt.
bool SampleMNIST::processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const
// Height and width are taken from the 2nd and 3rd entries of the input dimensions.
const int inputH = mInputDims.d[1];
const int inputW = mInputDims.d[2];
// Read a random digit file
std::vector<uint8_t> fileData(inputH * inputW);
readPGMFile(locateFile(std::to_string(inputFileIdx) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);// load the PGM image of the selected digit
// Print ASCII representation of digit
gLogInfo << "Input:\n";
for (int i = 0; i < inputH * inputW; i++)
// Dividing the 0-255 pixel value by 26 yields an index 0-9 into the 10-character ramp;
// a newline is emitted after every inputW pixels.
gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
gLogInfo << std::endl;
float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
for (int i = 0; i < inputH * inputW; i++)
hostInputBuffer[i] = float(fileData[i]);// copy each pixel into the host buffer as a float
return true;
//! \brief Verifies that the output is correct and prints it
//! \param buffers Buffer manager that provides the host buffer for the output tensor
//! \param outputTensorName Name of the output tensor to read
//! \param groundTruthDigit Digit the network is expected to predict
//! \return true when the highest-scoring index equals groundTruthDigit and its score exceeds 0.9
//! NOTE(review): brace-only lines are not present in this excerpt; the statements that use
//! the loop variable i must belong to the for loop body.
bool SampleMNIST::verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const
const float* prob = static_cast<const float*>(buffers.getHostBuffer(outputTensorName));
// Print histogram of the output distribution
gLogInfo << "Output:\n";
// Running maximum score and the index at which it was found.
float val{0.0f};
int idx{0};
// The first 10 entries of the output buffer are read, one per digit 0-9.
for (unsigned int i = 0; i < 10; i++)
val = std::max(val, prob[i]);
// val equals prob[i] exactly when entry i is the maximum seen so far.
if (val == prob[i])
idx = i;
// One '*' per 0.1 of score, rounded to the nearest integer.
gLogInfo << i << ": " << std::string(int(std::floor(prob[i] * 10 + 0.5f)), '*') << "\n";
gLogInfo << std::endl;
return (idx == groundTruthDigit && val > 0.9f);
//! \brief This function uses a caffe parser to create the MNIST Network and marks the
//! output layers
//! \param network Pointer to the network that will be populated with the MNIST network
//! \param builder Pointer to the engine builder
//! \param parser Caffe parser used to read the prototxt, weights and mean files
//! Uses the Caffe parser to parse the model and configuration files and create the MNIST network.
//! NOTE(review): in this excerpt the parser->parse(...) call is cut off after its second
//! argument, the output-tensor loop has no body, and brace-only lines are absent;
//! confirm the missing statements against the full sample source.
void SampleMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder, SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser)
// Parse the prototxt and weights files, both located through mParams.dataDirs.
const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor = parser->parse(
locateFile(mParams.prototxtFileName, mParams.dataDirs).c_str(),
locateFile(mParams.weightsFileName, mParams.dataDirs).c_str(),
// Iterates over the configured output tensor names (loop body not shown in this excerpt).
for (auto& s : mParams.outputTensorNames)
// add mean subtraction to the beginning of the network
Dims inputDims = network->getInput(0)->getDimensions();
// The blob is stored in a member because meanWeights below points into its data.
mMeanBlob = SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob>(parser->parseBinaryProto(locateFile(mParams.meanFileName, mParams.dataDirs).c_str()));
// One float weight per element of the d[1] x d[2] input plane.
Weights meanWeights{DataType::kFLOAT, mMeanBlob->getData(), inputDims.d[1] * inputDims.d[2]};
// For this sample, a large range based on the mean data is chosen and applied to the entire network.
// The preferred method is use scales computed based on a representative data set
// and apply each one individually based on the tensor. The range here is large enough for the
// network, but is chosen for example purposes only.
float maxMean = samplesCommon::getMaxValue(static_cast<const float*>(meanWeights.values), samplesCommon::volume(inputDims));
// Insert "input - mean" and feed its result to layer 0 in place of the raw network input.
auto mean = network->addConstant(Dims3(1, inputDims.d[1], inputDims.d[2]), meanWeights);
auto meanSub = network->addElementWise(*network->getInput(0), *mean->getOutput(0), ElementWiseOperation::kSUB);
network->getLayer(0)->setInput(0, *meanSub->getOutput(0));
samplesCommon::setAllTensorScales(network.get(), maxMean, maxMean);
//! \brief This function runs the TensorRT inference engine for this sample
//! \details This function is the main execution function of the sample. It allocates
//! the buffer, sets inputs, executes the engine, and verifies the output.
//! \return true when inference was enqueued successfully and verifyOutput() accepted the result
//! NOTE(review): several step comments below (host-to-device copy, device-to-host copy,
//! stream synchronisation and release) have no statement under them, and `stream` is
//! declared but never visibly created before it is passed to enqueue(); presumably those
//! lines were dropped from this excerpt along with the braces - confirm against the full source.
bool SampleMNIST::infer()// inference code
// Create RAII buffer manager object
samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);// create the buffers first
auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
if (!context)
return false;
// Pick a random digit to try to infer
const int digit = rand() % 10;// pick a random digit in 0-9
// Read the input data into the managed buffers
// There should be just 1 input tensor
assert(mParams.inputTensorNames.size() == 1);
if (!processInput(buffers, mParams.inputTensorNames[0], digit))// read the input data
return false;
// Create CUDA stream for the execution of this inference.
cudaStream_t stream;
// Asynchronously copy data from host input buffers to device input buffers
// Asynchronously enqueue the inference work
if (!context->enqueue(mParams.batchSize, buffers.getDeviceBindings().data(), stream, nullptr))
return false;// the inference work is enqueued asynchronously; fail if enqueue reports an error
// Asynchronously copy data from device output buffers to host output buffers
// Wait for the work in the stream to complete
// Release stream
// Check and print the output of the inference
// There should be just one output tensor
assert(mParams.outputTensorNames.size() == 1);
bool outputCorrect = verifyOutput(buffers, mParams.outputTensorNames[0], digit);// check the result
return outputCorrect;
//! \brief This function can be used to clean up any state created in the sample class
//! \return Always true in the code shown here
//! NOTE(review): the comments below describe a protobuf shutdown call that is not present
//! in this excerpt; presumably it was dropped along with the braces - confirm against the
//! full sample source.
bool SampleMNIST::teardown()
//! Clean up the libprotobuf files as the parsing is complete
//! \note It is not safe to use any other part of the protocol buffers library after
//! ShutdownProtobufLibrary() has been called.
return true;
//! \brief This function initializes members of the params struct using the command line args
//! \param args Parsed command-line arguments (data directories, DLA core, int8/fp16 flags)
//! \return The populated MNISTSampleParams
//! NOTE(review): the statements of the else branch (default data directories) and the
//! brace-only lines are not present in this excerpt; the file-name assignments below
//! presumably apply in both cases - confirm against the full sample source.
MNISTSampleParams initializeSampleParams(const samplesCommon::Args& args)
MNISTSampleParams params;
if (args.dataDirs.size() != 0) //!< Use the data directory provided by the user
params.dataDirs = args.dataDirs;
else //!< Use default directories if user hasn't provided directory paths
// Names of the three Caffe files the sample loads: network description, weights, mean image.
params.prototxtFileName = "mnist.prototxt";
params.weightsFileName = "mnist.caffemodel";
params.meanFileName = "mnist_mean.binaryproto";
// The sample processes one image per inference call.
params.batchSize = 1;
// Remaining options are forwarded unchanged from the command line.
params.dlaCore = args.useDLACore;
params.int8 = args.runInInt8;
params.fp16 = args.runInFp16;
return params;
//! \brief This function prints the help information for running this sample
//! \details Writes the usage synopsis to std::cout, followed by one description line for
//! each of --help, --datadir, --useDLACore, --int8 and --fp16. The function body is wrapped
//! in braces so that the definition is well-formed.
void printHelpInfo()
{
    std::cout << "Usage: ./sample_mnist [-h or --help] [-d or --datadir=] [--useDLACore=]\n";
    std::cout << "--help Display help information\n";
    std::cout << "--datadir Specify path to a data directory, overriding the default. This option can be used multiple times to add multiple directories. If no data directories are given, the default is to use (data/samples/mnist/, data/mnist/)" << std::endl;
    std::cout << "--useDLACore=N Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, where n is the number of DLA engines on the platform." << std::endl;
    std::cout << "--int8 Run in Int8 mode.\n";
    std::cout << "--fp16 Run in FP16 mode.\n";
}