TensorRT Algorithm Selection API usage example based on sampleMNIST in TensorRT

[tensorRT中基于sampleMNIST的算法选择API使用示例]

1 description

  1. IAlgorithmSelector：用于以确定性 (deterministic) 的方式构建 TRT engine，即由用户控制每一层所选用的算法
  2. IAlgorithmSelector::selectAlgorithms：定义用于选择算法 (selection of algorithm) 的启发式规则 (heuristics)

2 运行流程（Caffe）

  1. 使用 TensorRT 的 Caffe parser 完成基础的 setup 和 initialization
  2. 使用 Caffe parser 读取一个训练好的 Caffe 模型
  3. 在一个 managed buffer 中进行输入预处理并保存结果
  4. 创建 3 个算法选择器 (algorithm selector) 实例
  5. 分别使用这 3 个算法选择器构建 3 个 engine
  6. 序列化和反序列化 engine
  7. 使用 engine 对一张图片进行推理

3 建立 algorithm selectors

  1. AlgorithmCacheWriter：使用 IAlgorithmSelector::reportAlgorithms 将 TensorRT 默认选择的算法写入 “AlgorithmCache.txt” 文件中
  2. AlgorithmCacheReader：使用 IAlgorithmSelector::selectAlgorithms 从 “AlgorithmCache.txt” 中读取并复现算法选择，再通过 IAlgorithmSelector::reportAlgorithms 验证 TensorRT 实际采用的正是这些选择
  3. MinimumWorkspaceAlgorithmSelector：使用 IAlgorithmSelector::selectAlgorithms 为每一层选择所需工作空间 (workspace) 最小的算法

4 run_demo

  1. make
  2. ./sample_algorithm_selector [-h] [--datadir=/path/to/data/dir/] [--useDLACore=N] [--fp16 or --int8]

5 code

#include "common/argsParser.h"
#include "common/buffers.h"
#include "common/common.h"
#include "common/logger.h"
#include "NvCaffeParser.h"
#include "NvInfer.h"
#include <cuda_runtime_api.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

//! Test name this sample registers with the sample logger.
const std::string gSampleName{"TensorRT.sample_algorithm_selector"};
//! Algorithm cache file: written by AlgorithmCacheWriter, read back by AlgorithmCacheReader.
const std::string gCacheFileName{"AlgorithmCache.txt"};
class AlgorithmCacheWriter : public IAlgorithmSelector
{
public:
    AlgorithmCacheWriter(const std::string& cacheFileName)
        :mCacheFileName(cacheFileName)
    {

    }
    int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& context,
                             const nvinfer1::IAlgorithm* const* choices,
                             int32_t nbChoices, int32_t* selection) override
    {
        assert(nbChoices>0);
        std::iota(selection, selection+nbChoices,0);

//        std::vector words(8);
//        std::iota(std::begin(words),std::end(words),"mysterious");
//        std::copy(std::begin(words),std::end(words),std::ostream_iterator{std::cout," "});
//        std::cout<getName() <<"\n";
            algorithmFile << algoChoices[i]->getAlgorithmVariant().getImplementation() <<"\n";
            algorithmFile << algoChoices[i]->getAlgorithmVariant().getTactic() <<"\n";

            const int32_t nbInputs = algoContexts[i]->getNbInputs();
            algorithmFile << nbInputs<<"\n";
            const int32_t nbOutputs = algoContexts[i]->getNbOutputs();
            algorithmFile <(algoChoices[i]->getAlgorithmIOInfo(j).getTensorFormat()) <<"\n";
                algorithmFile << static_cast(algoChoices[i]->getAlgorithmIOInfo(j).getDataType()) <<"\n";
            }
        }
        algorithmFile.close();
    }

private:
    std::string mCacheFileName;
};
class AlgorithmCacheReader : public IAlgorithmSelector
{
public:
    int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
                             const nvinfer1::IAlgorithm* const* algoChoices,
                             int32_t nbChoices, int32_t* selection) override
    {
        assert(nbChoices >0);
        const std::string layerName(algoContext.getName());
        auto it = choiceMap.find(layerName);
        assert(it != choiceMap.end());
        auto& algoItem = it->second;
        assert(algoItem.nbInputs == algoContext.getNbInputs());
        assert(algoItem.nbOutputs == algoContext.getNbOutputs());
        int32_t nbSelections = 0;
        for(auto i=0;igetName());
            assert(choiceMap.find(layerName) != choiceMap.end());
            const auto& algoItem = choiceMap[layerName];
            assert(algoItem.nbInputs == algoContexts[i]->getNbInputs());
            assert(algoItem.nbOutputs == algoContexts[i]->getNbOutputs());
            assert(algoChoices[i]->getAlgorithmVariant().getImplementation() == algoItem.implementation);
            assert(algoChoices[i]->getAlgorithmVariant().getTactic() == algoItem.tactic);
            auto nbFormats = algoItem.nbInputs + algoItem.nbOutputs;
            for (auto j = 0; j < nbFormats; j++)
            {
                assert(algoItem.formats[j].first
                    == static_cast(algoChoices[i]->getAlgorithmIOInfo(j).getTensorFormat()));
                assert(algoItem.formats[j].second
                    == static_cast(algoChoices[i]->getAlgorithmIOInfo(j).getDataType()));
            }
        }
    }
    AlgorithmCacheReader(const std::string& cacheFileName)
    {
        std::ifstream algorithmFile(cacheFileName);
        if (!algorithmFile.good())
        {
            sample::gLogError << "Cannot open algorithm cache file: " << cacheFileName << " to read." << std::endl;
            abort();
        }

        std::string line;
        while (getline(algorithmFile, line))
        {
            std::string layerName;
            layerName = line;

            AlgorithmCacheItem algoItem;
            getline(algorithmFile, line);
            algoItem.implementation = std::stoll(line);

            getline(algorithmFile, line);
            algoItem.tactic = std::stoll(line);

            getline(algorithmFile, line);
            algoItem.nbInputs = std::stoi(line);

            getline(algorithmFile, line);
            algoItem.nbOutputs = std::stoi(line);

            const int32_t nbFormats = algoItem.nbInputs + algoItem.nbOutputs;
            algoItem.formats.resize(nbFormats);
            for (int32_t i = 0; i < nbFormats; i++)
            {
                getline(algorithmFile, line);
                algoItem.formats[i].first = std::stoi(line);
                getline(algorithmFile, line);
                algoItem.formats[i].second = std::stoi(line);
            }
            choiceMap[layerName] = std::move(algoItem);
        }
        algorithmFile.close();
    }
private:
    struct AlgorithmCacheItem
    {
        int64_t implementation;
        int64_t tactic;
        int32_t nbInputs;
        int32_t nbOutputs;
        std::vector> formats;
    };
    std::unordered_map choiceMap;
    static bool areSame(const AlgorithmCacheItem& algoCacheItem, const IAlgorithm& algoChoice)
    {
        if(algoChoice.getAlgorithmVariant().getImplementation() != algoCacheItem.implementation
                || algoChoice.getAlgorithmVariant().getTactic() != algoCacheItem.tactic)
        {
            return false;
        }
        const auto nbFormats = algoCacheItem.nbInputs + algoCacheItem.nbOutputs;
        for(auto j=0;j(algoChoice.getAlgorithmIOInfo(j).getTensorFormat())
                    || algoCacheItem.formats[j].second != static_cast(algoChoice.getAlgorithmIOInfo(j).getDataType()))
            {
                return false;
            }
        }
        return true;
    }
};

//!
//! \brief Algorithm selector that picks, per layer, the candidate needing the least workspace.
//!
class MinimumWorkspaceAlgorithmSelector : public IAlgorithmSelector
{
public:
    //!
    //! \brief Writes the index of the smallest-workspace candidate into selection[0].
    //!
    //! Ties resolve to the earliest such candidate (std::min_element semantics).
    //! \return 1, the number of selected indices.
    //!
    int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
        const nvinfer1::IAlgorithm* const* algoChoices, int32_t nbChoices, int32_t* selection) override
    {
        assert(nbChoices > 0);
        const auto best = std::min_element(algoChoices, algoChoices + nbChoices,
            [](const nvinfer1::IAlgorithm* x, const nvinfer1::IAlgorithm* y) {
                return x->getWorkspaceSize() < y->getWorkspaceSize();
            });
        selection[0] = static_cast<int32_t>(best - algoChoices);
        return 1;
    }

    //! Intentionally empty: this selector does not record or verify the final choices.
    void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
        const nvinfer1::IAlgorithm* const* algoChoices, int32_t nbAlgorithms) override
    {
    }
};
class SampleAlgorithmSelector
{
    template
    using SampleUniquePtr = std::unique_ptr;
public:
    SampleAlgorithmSelector(const samplesCommon::CaffeSampleParams& params)
        :mParams(params)
    {

    }
    bool build(IAlgorithmSelector* selector);
    bool infer();
    bool teardown();
private:
    bool constructNetwork(SampleUniquePtr& parser, SampleUniquePtr& network);
    bool processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIndex) const;
    bool verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const;
    std::shared_ptr mEngine{nullptr};
    samplesCommon::CaffeSampleParams mParams;
    nvinfer1::Dims mInputDims;
    SampleUniquePtr mMeanBlob;
};
//!
//! \brief Builds a TensorRT engine from the Caffe MNIST model, letting \p selector choose algorithms.
//!
//! On success mEngine holds the new engine (replacing any previous one) and mInputDims
//! holds the network input dimensions.
//! \param selector Installed on this build's local builder config; not owned.
//! \return false if a TensorRT object cannot be created or the engine build fails.
//!
bool SampleAlgorithmSelector::build(IAlgorithmSelector* selector)
{
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if (!network)
    {
        return false;
    }
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }
    auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
    if (!parser)
    {
        return false;
    }
    if (!constructNetwork(parser, network))
    {
        return false;
    }

    builder->setMaxBatchSize(mParams.batchSize);
    config->setMaxWorkspaceSize(16_MiB);
    config->setAlgorithmSelector(selector);
    config->setFlag(BuilderFlag::kGPU_FALLBACK);
    // Strict types only outside INT8 mode.
    if (!mParams.int8)
    {
        config->setFlag(BuilderFlag::kSTRICT_TYPES);
    }
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
    }
    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }

    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 3);
    return true;
}
//!
//! \brief Reads "<inputFileIndex>.pgm" from the data dirs, logs it as ASCII art, and copies
//!        its pixels (as float) into the host buffer of \p inputTensorName.
//!
//! Uses mInputDims.d[1] / d[2] as height / width.
//! \return true (file-lookup/read errors are handled by locateFile/readPGMFile).
//!
bool SampleAlgorithmSelector::processInput(
    const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIndex) const
{
    const int inputH = mInputDims.d[1];
    const int inputW = mInputDims.d[2];

    std::vector<uint8_t> fileData(inputH * inputW);
    readPGMFile(locateFile(std::to_string(inputFileIndex) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);

    // Print ASCII representation of digit (256 grey levels mapped onto 10 glyphs).
    sample::gLogInfo << "Input:\n";
    for (int i = 0; i < inputH * inputW; i++)
    {
        sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
    }
    sample::gLogInfo << std::endl;

    float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
    for (int i = 0; i < inputH * inputW; i++)
    {
        hostInputBuffer[i] = static_cast<float>(fileData[i]);
    }
    return true;
}

//!
//! \brief Logs a histogram of the first 10 output probabilities and checks the prediction.
//!
//! \return true if the arg-max class equals \p groundTruthDigit and its probability exceeds 0.9.
//!
bool SampleAlgorithmSelector::verifyOutput(
    const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const
{
    const float* prob = static_cast<const float*>(buffers.getHostBuffer(outputTensorName));
    sample::gLogInfo << "Output:\n";
    float val{0.0f};
    int idx{0};
    constexpr int kDIGITS = 10;

    for (int i = 0; i < kDIGITS; i++)
    {
        if (val < prob[i])
        {
            val = prob[i];
            idx = i;
        }

        sample::gLogInfo << i << ": " << std::string(static_cast<int>(std::floor(prob[i] * 10 + 0.5f)), '*') << "\n";
    }
    sample::gLogInfo << std::endl;

    return (idx == groundTruthDigit && val > 0.9f);
}
bool SampleAlgorithmSelector::constructNetwork(SampleUniquePtr &parser,
                                                SampleUniquePtr &network)
{
    const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor = parser->parse(
                mParams.prototxtFileName.c_str(), mParams.weightsFileName.c_str(), *network, nvinfer1::DataType::kFLOAT);
    for(auto& s: mParams.outputTensorNames)
    {
        network->markOutput(*blobNameToTensor->find(s.c_str()));
    }
    nvinfer1::Dims inputDims = network->getInput(0)->getDimensions();
    mMeanBlob = SampleUniquePtr(parser->parseBinaryProto(mParams.meanFileName.c_str()));
    nvinfer1::Weights meanWeights{nvinfer1::DataType::kFLOAT, mMeanBlob->getData(), inputDims.d[1]*inputDims.d[1]};
    float maxMean = samplesCommon::getMaxValue(static_cast(meanWeights.values),samplesCommon::volume(inputDims));
    auto mean = network->addConstant(nvinfer1::Dims3(1, inputDims.d[1], inputDims.d[2]), meanWeights);
    if(!mean->getOutput(0)->setDynamicRange(-maxMean, maxMean))
    {
        return false;
    }
    if(!network->getInput(0)->setDynamicRange(-maxMean,maxMean))
    {
        return false;
    }
    auto meanSub = network->addElementWise(*network->getInput(0),*mean->getOutput(0), ElementWiseOperation::kSUB);
    if(!meanSub->getOutput(0)->setDynamicRange(-maxMean,maxMean))
    {
        return false;
    }
    network->getLayer(0)->setInput(0,*meanSub->getOutput(0));
    samplesCommon::setAllTensorScales(network.get(),127.0f,127.0f);
    return true;
}
bool SampleAlgorithmSelector::infer()
{
    // Create RAII buffer manager object.
    samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);

    auto context = SampleUniquePtr(mEngine->createExecutionContext());
    if (!context)
    {
        return false;
    }

    // Pick a random digit to try to infer.
    srand(time(NULL));
    const int digit = rand() % 10;

    // Read the input data into the managed buffers.
    // There should be just 1 input tensor.
    assert(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers, mParams.inputTensorNames[0], digit))
    {
        return false;
    }
    // Create CUDA stream for the execution of this inference.
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    // Asynchronously copy data from host input buffers to device input buffers
    buffers.copyInputToDeviceAsync(stream);

    // Asynchronously enqueue the inference work
    if (!context->enqueue(mParams.batchSize, buffers.getDeviceBindings().data(), stream, nullptr))
    {
        return false;
    }
    // Asynchronously copy data from device output buffers to host output buffers.
    buffers.copyOutputToHostAsync(stream);

    // Wait for the work in the stream to complete.
    cudaStreamSynchronize(stream);

    // Release stream.
    cudaStreamDestroy(stream);

    // Check and print the output of the inference.
    // There should be just one output tensor.
    assert(mParams.outputTensorNames.size() == 1);
    bool outputCorrect = verifyOutput(buffers, mParams.outputTensorNames[0], digit);

    return outputCorrect;
}
//!
//! \brief Releases the protobuf library state used by the Caffe parser.
//! \return Always true.
//!
bool SampleAlgorithmSelector::teardown()
{
    const bool succeeded{true};
    nvcaffeparser1::shutdownProtobufLibrary();
    return succeeded;
}
//!
//! \brief Translates parsed command-line arguments into the MNIST Caffe sample parameters.
//!
//! Data directories come from --datadir, falling back to "data/". Model, weight and mean
//! files are resolved against those directories; tensor names, batch size, DLA core and
//! precision flags are filled in as well.
//!
samplesCommon::CaffeSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    samplesCommon::CaffeSampleParams params;

    if (!args.dataDirs.empty())
    {
        // Honour the directories the user passed on the command line.
        params.dataDirs = args.dataDirs;
    }
    else
    {
        // No --datadir given: search the default location.
        params.dataDirs.push_back("data/");
    }

    // Resolve model assets (order kept: prototxt, weights, mean).
    params.prototxtFileName = locateFile("mnist.prototxt", params.dataDirs);
    params.weightsFileName = locateFile("mnist.caffemodel", params.dataDirs);
    params.meanFileName = locateFile("mnist_mean.binaryproto", params.dataDirs);

    // Network I/O and execution settings.
    params.batchSize = 1;
    params.inputTensorNames.push_back("data");
    params.outputTensorNames.push_back("prob");

    // Hardware / precision options straight from the CLI.
    params.fp16 = args.runInFp16;
    params.int8 = args.runInInt8;
    params.dlaCore = args.useDLACore;

    return params;
}
//!
//! \brief Prints command-line usage for this sample to stdout.
//!
//! The default data directory listed here must match initializeSampleParams().
//!
void printHelpInfo()
{
    std::cout << "Usage: ./sample_algorithm_selector [-h or --help] [-d or --datadir=<path to data directory>] "
                 "[--useDLACore=<int>] [--int8] [--fp16]\n";
    std::cout << "--help          Display help information\n";
    std::cout << "--datadir       Specify path to a data directory, overriding the default. This option can be used "
                 "multiple times to add multiple directories. If no data directories are given, the default is to use "
                 "(data/)\n";
    std::cout << "--useDLACore=N  Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
                 "where n is the number of DLA engines on the platform.\n";
    std::cout << "--int8          Run in Int8 mode.\n";
    std::cout << "--fp16          Run in FP16 mode." << std::endl;
}
int main(int argc, char** argv)
{
    samplesCommon::Args args;
    bool argsOK = samplesCommon::parseArgs(args, argc, argv);
    if (!argsOK)
    {
        sample::gLogError << "Invalid arguments" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }
    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }

    auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);

    sample::gLogger.reportTestStart(sampleTest);

    samplesCommon::CaffeSampleParams params = initializeSampleParams(args);
    // Write Algorithm Cache.
    SampleAlgorithmSelector sampleAlgorithmSelector(params);

    {
        sample::gLogInfo << "Building and running a GPU inference engine for MNIST." << std::endl;
        sample::gLogInfo << "Writing Algorithm Cache for MNIST." << std::endl;
        AlgorithmCacheWriter algorithmCacheWriter(gCacheFileName);

        if (!sampleAlgorithmSelector.build(&algorithmCacheWriter))
        {
            return sample::gLogger.reportFail(sampleTest);
        }

        if (!sampleAlgorithmSelector.infer())
        {
            return sample::gLogger.reportFail(sampleTest);
        }
    }

    {
        // Build network using Cache from previous run.
        sample::gLogInfo << "Building a GPU inference engine for MNIST using Algorithm Cache." << std::endl;
        AlgorithmCacheReader algorithmCacheReader(gCacheFileName);

        if (!sampleAlgorithmSelector.build(&algorithmCacheReader))
        {
            return sample::gLogger.reportFail(sampleTest);
        }

        if (!sampleAlgorithmSelector.infer())
        {
            return sample::gLogger.reportFail(sampleTest);
        }
    }

    {
        // Build network using MinimumWorkspaceAlgorithmSelector.
        sample::gLogInfo
            << "Building a GPU inference engine for MNIST using Algorithms with minimum workspace requirements."
            << std::endl;
        MinimumWorkspaceAlgorithmSelector minimumWorkspaceAlgorithmSelector;
        if (!sampleAlgorithmSelector.build(&minimumWorkspaceAlgorithmSelector))
        {
            return sample::gLogger.reportFail(sampleTest);
        }

        if (!sampleAlgorithmSelector.infer())
        {
            return sample::gLogger.reportFail(sampleTest);
        }
    }

    if (!sampleAlgorithmSelector.teardown())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    return sample::gLogger.reportPass(sampleTest);
}

你可能感兴趣的:(笔记,c++,caffe)