[Using TensorRT's algorithm-selection API: an example based on sampleMNIST]
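The listing below follows NVIDIA's sampleAlgorithmSelector sample: the MNIST Caffe engine is built three times, first letting TensorRT pick algorithms freely while recording its choices to AlgorithmCache.txt, then replaying those choices from the cache, and finally selecting, for each layer, the candidate algorithm with the smallest workspace requirement.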
#include "common/argsParser.h"
#include "common/buffers.h"
#include "common/common.h"
#include "common/logger.h"
#include "NvCaffeParser.h"
#include "NvInfer.h"
#include <cuda_runtime_api.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <numeric>
#include <string>
#include <unordered_map>
// The sample uses IAlgorithmSelector, BuilderFlag, ElementWiseOperation etc. unqualified.
using namespace nvinfer1;
const std::string gSampleName = "TensorRT.sample_algorithm_selector";
const std::string gCacheFileName = "AlgorithmCache.txt";
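//! Writes the algorithm choices made by TensorRT into a cache file. It does not
//! restrict TensorRT's selection in any way; it only records the outcome.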
class AlgorithmCacheWriter : public IAlgorithmSelector
{
public:
AlgorithmCacheWriter(const std::string& cacheFileName)
:mCacheFileName(cacheFileName)
{
}
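//! Return all candidate indices so TensorRT keeps its default tactic selection.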
int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& context,
const nvinfer1::IAlgorithm* const* choices,
int32_t nbChoices, int32_t* selection) override
{
assert(nbChoices>0);
std::iota(selection, selection+nbChoices,0);
// std::iota fills a range with successive increments of its third argument, e.g.:
// std::vector<std::string> words(8);
// std::iota(std::begin(words), std::end(words), "mysterious");
// std::copy(std::begin(words), std::end(words), std::ostream_iterator<std::string>{std::cout, " "});
// std::cout << std::endl;
return nbChoices;
}
//! Called by TensorRT to report the choices it made; writes them to the cache file.
void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
const nvinfer1::IAlgorithm* const* algoChoices, int32_t nbAlgorithms) override
{
std::ofstream algorithmFile(mCacheFileName);
if (!algorithmFile.good())
{
sample::gLogError << "Cannot open algorithm cache file: " << mCacheFileName << " to write." << std::endl;
abort();
}
for (int32_t i = 0; i < nbAlgorithms; i++)
{
// One record per layer: name, implementation, tactic, I/O counts and formats.
algorithmFile << algoContexts[i]->getName() << "\n";
algorithmFile << algoChoices[i]->getAlgorithmVariant().getImplementation() <<"\n";
algorithmFile << algoChoices[i]->getAlgorithmVariant().getTactic() <<"\n";
const int32_t nbInputs = algoContexts[i]->getNbInputs();
algorithmFile << nbInputs<<"\n";
const int32_t nbOutputs = algoContexts[i]->getNbOutputs();
algorithmFile << nbOutputs << "\n";
// Write the tensor format and data type of every input and output.
for (int32_t j = 0; j < nbInputs + nbOutputs; j++)
{
algorithmFile << static_cast<int32_t>(algoChoices[i]->getAlgorithmIOInfo(j).getTensorFormat()) << "\n";
algorithmFile << static_cast<int32_t>(algoChoices[i]->getAlgorithmIOInfo(j).getDataType()) << "\n";
}
}
algorithmFile.close();
}
private:
std::string mCacheFileName;
};
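//! Replays the algorithm choices recorded by AlgorithmCacheWriter: the constructor
//! loads the cache file, selectAlgorithms() admits only the cached algorithm for
//! each layer, and reportAlgorithms() verifies that TensorRT actually used it.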
class AlgorithmCacheReader : public IAlgorithmSelector
{
public:
int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
const nvinfer1::IAlgorithm* const* algoChoices,
int32_t nbChoices, int32_t* selection) override
{
assert(nbChoices >0);
const std::string layerName(algoContext.getName());
auto it = choiceMap.find(layerName);
assert(it != choiceMap.end());
auto& algoItem = it->second;
assert(algoItem.nbInputs == algoContext.getNbInputs());
assert(algoItem.nbOutputs == algoContext.getNbOutputs());
int32_t nbSelections = 0;
for (auto i = 0; i < nbChoices; i++)
{
// The combination of implementation, tactic and I/O formats is unique to an
// algorithm, so select exactly the candidates that match the cached entry.
if (areSame(algoItem, *algoChoices[i]))
{
selection[nbSelections++] = i;
}
}
return nbSelections;
}
//! Called by TensorRT to report the choices it made; verifies they match the cache.
void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
const nvinfer1::IAlgorithm* const* algoChoices, int32_t nbAlgorithms) override
{
for (auto i = 0; i < nbAlgorithms; i++)
{
const std::string layerName(algoContexts[i]->getName());
assert(choiceMap.find(layerName) != choiceMap.end());
const auto& algoItem = choiceMap[layerName];
assert(algoItem.nbInputs == algoContexts[i]->getNbInputs());
assert(algoItem.nbOutputs == algoContexts[i]->getNbOutputs());
assert(algoChoices[i]->getAlgorithmVariant().getImplementation() == algoItem.implementation);
assert(algoChoices[i]->getAlgorithmVariant().getTactic() == algoItem.tactic);
auto nbFormats = algoItem.nbInputs + algoItem.nbOutputs;
for (auto j = 0; j < nbFormats; j++)
{
assert(algoItem.formats[j].first
== static_cast<int32_t>(algoChoices[i]->getAlgorithmIOInfo(j).getTensorFormat()));
assert(algoItem.formats[j].second
== static_cast<int32_t>(algoChoices[i]->getAlgorithmIOInfo(j).getDataType()));
}
}
}
AlgorithmCacheReader(const std::string& cacheFileName)
{
std::ifstream algorithmFile(cacheFileName);
if (!algorithmFile.good())
{
sample::gLogError << "Cannot open algorithm cache file: " << cacheFileName << " to read." << std::endl;
abort();
}
std::string line;
while (getline(algorithmFile, line))
{
std::string layerName;
layerName = line;
AlgorithmCacheItem algoItem;
getline(algorithmFile, line);
algoItem.implementation = std::stoll(line);
getline(algorithmFile, line);
algoItem.tactic = std::stoll(line);
getline(algorithmFile, line);
algoItem.nbInputs = std::stoi(line);
getline(algorithmFile, line);
algoItem.nbOutputs = std::stoi(line);
const int32_t nbFormats = algoItem.nbInputs + algoItem.nbOutputs;
algoItem.formats.resize(nbFormats);
for (int32_t i = 0; i < nbFormats; i++)
{
getline(algorithmFile, line);
algoItem.formats[i].first = std::stoi(line);
getline(algorithmFile, line);
algoItem.formats[i].second = std::stoi(line);
}
choiceMap[layerName] = std::move(algoItem);
}
algorithmFile.close();
}
private:
struct AlgorithmCacheItem
{
int64_t implementation;
int64_t tactic;
int32_t nbInputs;
int32_t nbOutputs;
std::vector<std::pair<int32_t, int32_t>> formats;
};
std::unordered_map<std::string, AlgorithmCacheItem> choiceMap;
static bool areSame(const AlgorithmCacheItem& algoCacheItem, const IAlgorithm& algoChoice)
{
if(algoChoice.getAlgorithmVariant().getImplementation() != algoCacheItem.implementation
|| algoChoice.getAlgorithmVariant().getTactic() != algoCacheItem.tactic)
{
return false;
}
const auto nbFormats = algoCacheItem.nbInputs + algoCacheItem.nbOutputs;
for (auto j = 0; j < nbFormats; j++)
{
if (algoCacheItem.formats[j].first != static_cast<int32_t>(algoChoice.getAlgorithmIOInfo(j).getTensorFormat())
|| algoCacheItem.formats[j].second != static_cast<int32_t>(algoChoice.getAlgorithmIOInfo(j).getDataType()))
{
return false;
}
}
return true;
}
};
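//! Selects, for every layer, the candidate algorithm with the smallest device
//! workspace requirement.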
class MinimumWorkspaceAlgorithmSelector : public IAlgorithmSelector
{
public:
int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
const nvinfer1::IAlgorithm* const* algoChoices,
int32_t nbChoices, int32_t* selection) override
{
assert(nbChoices > 0);
auto it = std::min_element(
algoChoices, algoChoices+nbChoices,[](const nvinfer1::IAlgorithm* x, const nvinfer1::IAlgorithm* y){
return x->getWorkspaceSize() < y->getWorkspaceSize();
});
selection[0] = static_cast<int32_t>(it - algoChoices);
return 1;
}
void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
const nvinfer1::IAlgorithm* const* algoChoices, int32_t nbAlgorithms) override
{
// Nothing to verify or record for this selector.
}
};
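//! Wraps the usual sampleMNIST flow (build/infer/teardown); build() takes the
//! IAlgorithmSelector to attach to the builder config.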
class SampleAlgorithmSelector
{
template <typename T>
using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;
public:
SampleAlgorithmSelector(const samplesCommon::CaffeSampleParams& params)
:mParams(params)
{
}
bool build(IAlgorithmSelector* selector);
bool infer();
bool teardown();
private:
bool constructNetwork(SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser, SampleUniquePtr<nvinfer1::INetworkDefinition>& network);
bool processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIndex) const;
bool verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const;
std::shared_ptr<nvinfer1::ICudaEngine> mEngine{nullptr}; //!< The TensorRT engine used to run the network.
samplesCommon::CaffeSampleParams mParams;
nvinfer1::Dims mInputDims;
SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob> mMeanBlob; //!< The mean blob; must stay alive until the engine is built.
};
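//! Creates the network, configures the builder (attaching the given algorithm
//! selector) and builds the engine.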
bool SampleAlgorithmSelector::build(IAlgorithmSelector *selector)
{
auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
if(!builder)
{
return false;
}
auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
if(!network)
{
return false;
}
auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
if(!config)
{
return false;
}
auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
if(!parser)
{
return false;
}
if(!constructNetwork(parser, network))
{
return false;
}
builder->setMaxBatchSize(mParams.batchSize);
config->setMaxWorkspaceSize(16_MiB);
config->setAlgorithmSelector(selector);
config->setFlag(BuilderFlag::kGPU_FALLBACK);
if(!mParams.int8)
{
// The sample fails for Int8 with the kSTRICT_TYPES flag set.
config->setFlag(BuilderFlag::kSTRICT_TYPES);
}
if(mParams.fp16)
{
config->setFlag(BuilderFlag::kFP16);
}
if(mParams.int8)
{
config->setFlag(BuilderFlag::kINT8);
}
samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
if(!mEngine)
{
return false;
}
assert(network->getNbInputs() == 1);
mInputDims = network->getInput(0)->getDimensions();
assert(mInputDims.nbDims == 3);
return true;
}
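//! Reads a PGM digit file into the host input buffer and prints it as ASCII art.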
bool SampleAlgorithmSelector::processInput(const samplesCommon::BufferManager &buffers,
const std::string &inputTensorName,
int inputFileIndex) const
{
const int inputH = mInputDims.d[1];
const int inputW = mInputDims.d[2];
srand(unsigned(time(nullptr)));
std::vector<uint8_t> fileData(inputH * inputW);
readPGMFile(locateFile(std::to_string(inputFileIndex)+".pgm",mParams.dataDirs), fileData.data(), inputH, inputW);
// Print ASCII representation of digit.
sample::gLogInfo << "Input:\n";
for (int i = 0; i < inputH * inputW; i++)
{
sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
}
sample::gLogInfo << std::endl;
float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
for (int i = 0; i < inputH * inputW; i++)
{
// Mean subtraction happens inside the network, so the raw pixel value is fed in.
hostInputBuffer[i] = float(fileData[i]);
}
return true;
}
//! Verifies that the output is correct and prints a histogram of the class scores.
bool SampleAlgorithmSelector::verifyOutput(const samplesCommon::BufferManager& buffers,
const std::string& outputTensorName, int groundTruthDigit) const
{
const float* prob = static_cast<const float*>(buffers.getHostBuffer(outputTensorName));
sample::gLogInfo << "Output:\n";
float val{0.0f};
int idx{0};
const int kDIGITS = 10;
for (int i = 0; i < kDIGITS; i++)
{
if (val < prob[i])
{
val = prob[i];
idx = i;
}
sample::gLogInfo << i << ": " << std::string(int(std::floor(prob[i] * 10 + 0.5f)), '*') << "\n";
}
sample::gLogInfo << std::endl;
return (idx == groundTruthDigit && val > 0.9f);
}
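//! Uses a Caffe parser to create the MNIST network, marks its output, and prepends
//! a constant + elementwise layer that subtracts the mean image from the input.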
bool SampleAlgorithmSelector::constructNetwork(SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser,
SampleUniquePtr<nvinfer1::INetworkDefinition>& network)
{
const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor = parser->parse(
mParams.prototxtFileName.c_str(), mParams.weightsFileName.c_str(), *network, nvinfer1::DataType::kFLOAT);
for(auto& s: mParams.outputTensorNames)
{
network->markOutput(*blobNameToTensor->find(s.c_str()));
}
nvinfer1::Dims inputDims = network->getInput(0)->getDimensions();
mMeanBlob = SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob>(parser->parseBinaryProto(mParams.meanFileName.c_str()));
nvinfer1::Weights meanWeights{nvinfer1::DataType::kFLOAT, mMeanBlob->getData(), inputDims.d[1] * inputDims.d[2]};
// For Int8, every tensor needs a dynamic range; derive it from the mean data.
float maxMean = samplesCommon::getMaxValue(static_cast<const float*>(meanWeights.values), samplesCommon::volume(inputDims));
auto mean = network->addConstant(nvinfer1::Dims3(1, inputDims.d[1], inputDims.d[2]), meanWeights);
if(!mean->getOutput(0)->setDynamicRange(-maxMean, maxMean))
{
return false;
}
if(!network->getInput(0)->setDynamicRange(-maxMean,maxMean))
{
return false;
}
auto meanSub = network->addElementWise(*network->getInput(0),*mean->getOutput(0), ElementWiseOperation::kSUB);
if(!meanSub->getOutput(0)->setDynamicRange(-maxMean,maxMean))
{
return false;
}
network->getLayer(0)->setInput(0,*meanSub->getOutput(0));
samplesCommon::setAllTensorScales(network.get(),127.0f,127.0f);
return true;
}
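//! Runs the TensorRT inference engine on a randomly chosen digit and checks the result.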
bool SampleAlgorithmSelector::infer()
{
// Create RAII buffer manager object.
samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);
auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
if (!context)
{
return false;
}
// Pick a random digit to try to infer.
srand(unsigned(time(nullptr)));
const int digit = rand() % 10;
// Read the input data into the managed buffers.
// There should be just 1 input tensor.
assert(mParams.inputTensorNames.size() == 1);
if (!processInput(buffers, mParams.inputTensorNames[0], digit))
{
return false;
}
// Create CUDA stream for the execution of this inference.
cudaStream_t stream;
CHECK(cudaStreamCreate(&stream));
// Asynchronously copy data from host input buffers to device input buffers
buffers.copyInputToDeviceAsync(stream);
// Asynchronously enqueue the inference work
if (!context->enqueue(mParams.batchSize, buffers.getDeviceBindings().data(), stream, nullptr))
{
return false;
}
// Asynchronously copy data from device output buffers to host output buffers.
buffers.copyOutputToHostAsync(stream);
// Wait for the work in the stream to complete.
cudaStreamSynchronize(stream);
// Release stream.
cudaStreamDestroy(stream);
// Check and print the output of the inference.
// There should be just one output tensor.
assert(mParams.outputTensorNames.size() == 1);
bool outputCorrect = verifyOutput(buffers, mParams.outputTensorNames[0], digit);
return outputCorrect;
}
bool SampleAlgorithmSelector::teardown()
{
nvcaffeparser1::shutdownProtobufLibrary();
return true;
}
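//! Initializes members of the params struct using the command-line arguments.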
samplesCommon::CaffeSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
samplesCommon::CaffeSampleParams params;
if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths.
{
params.dataDirs.push_back("data/");
}
else //!< Use the data directory provided by the user.
{
params.dataDirs = args.dataDirs;
}
params.prototxtFileName = locateFile("mnist.prototxt", params.dataDirs);
params.weightsFileName = locateFile("mnist.caffemodel", params.dataDirs);
params.meanFileName = locateFile("mnist_mean.binaryproto", params.dataDirs);
params.inputTensorNames.push_back("data");
params.batchSize = 1;
params.outputTensorNames.push_back("prob");
params.dlaCore = args.useDLACore;
params.int8 = args.runInInt8;
params.fp16 = args.runInFp16;
return params;
}
void printHelpInfo()
{
std::cout << "Usage: ./sample_algorithm_selector [-h or --help] [-d or --datadir=] "
"[--useDLACore=]\n";
std::cout << "--help Display help information\n";
std::cout << "--datadir Specify path to a data directory, overriding the default. This option can be used "
"multiple times to add multiple directories. If no data directories are given, the default is to use "
"(data/samples/mnist/, data/mnist/)"
<< std::endl;
std::cout << "--useDLACore=N Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
"where n is the number of DLA engines on the platform."
<< std::endl;
std::cout << "--int8 Run in Int8 mode.\n";
std::cout << "--fp16 Run in FP16 mode.\n";
}
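//! Builds and runs the engine three times: recording the algorithm cache, replaying
//! it, and using the minimum-workspace selector.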
int main(int argc, char** argv)
{
samplesCommon::Args args;
bool argsOK = samplesCommon::parseArgs(args, argc, argv);
if (!argsOK)
{
sample::gLogError << "Invalid arguments" << std::endl;
printHelpInfo();
return EXIT_FAILURE;
}
if (args.help)
{
printHelpInfo();
return EXIT_SUCCESS;
}
auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
sample::gLogger.reportTestStart(sampleTest);
samplesCommon::CaffeSampleParams params = initializeSampleParams(args);
// Write Algorithm Cache.
SampleAlgorithmSelector sampleAlgorithmSelector(params);
{
sample::gLogInfo << "Building and running a GPU inference engine for MNIST." << std::endl;
sample::gLogInfo << "Writing Algorithm Cache for MNIST." << std::endl;
AlgorithmCacheWriter algorithmCacheWriter(gCacheFileName);
if (!sampleAlgorithmSelector.build(&algorithmCacheWriter))
{
return sample::gLogger.reportFail(sampleTest);
}
if (!sampleAlgorithmSelector.infer())
{
return sample::gLogger.reportFail(sampleTest);
}
}
{
// Build network using Cache from previous run.
sample::gLogInfo << "Building a GPU inference engine for MNIST using Algorithm Cache." << std::endl;
AlgorithmCacheReader algorithmCacheReader(gCacheFileName);
if (!sampleAlgorithmSelector.build(&algorithmCacheReader))
{
return sample::gLogger.reportFail(sampleTest);
}
if (!sampleAlgorithmSelector.infer())
{
return sample::gLogger.reportFail(sampleTest);
}
}
{
// Build network using MinimumWorkspaceAlgorithmSelector.
sample::gLogInfo
<< "Building a GPU inference engine for MNIST using Algorithms with minimum workspace requirements."
<< std::endl;
MinimumWorkspaceAlgorithmSelector minimumWorkspaceAlgorithmSelector;
if (!sampleAlgorithmSelector.build(&minimumWorkspaceAlgorithmSelector))
{
return sample::gLogger.reportFail(sampleTest);
}
if (!sampleAlgorithmSelector.infer())
{
return sample::gLogger.reportFail(sampleTest);
}
}
if (!sampleAlgorithmSelector.teardown())
{
return sample::gLogger.reportFail(sampleTest);
}
return sample::gLogger.reportPass(sampleTest);
}