如何使用TensorRT Network API 重构网络

原始学习代码来自github.com/wang-xinyu/tensorrtx
tensorRT是用来专门部署前向的框架,可以使用解析的方式将网络部署到该框架下(这个方法就不多说了,自己可以找资料学习),这里重点学习使用Network API 重构网络。

一、创建一个空的network

     // 创建一个网络生成器
    IBuilder* builder = createInferBuilder(gLogger);
    //生成一个空的网络
    INetworkDefinition* network = builder->createNetwork();

二、给空的network添加常用层

//1、输入层
ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims3{3, INPUT_H, INPUT_W});
//2、卷积层
IConvolutionLayer* conv1 = network->addConvolution(*data, 64, DimsHW{11, 11}, weightMap["features.0.weight"], weightMap["features.0.bias"]);
conv1->setStride(DimsHW{4, 4});
conv1->setPadding(DimsHW{2, 2});
//3、激活层
IActivationLayer* relu1 = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU);
//4、池化层
IPoolingLayer* pool1 = network->addPooling(*relu1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
pool1->setStride(DimsHW{2, 2});
//5、全连接层
IFullyConnectedLayer* fc1 = network->addFullyConnected(*pool3->getOutput(0), 4096, weightMap["classifier.1.weight"], weightMap["classifier.1.bias"]);
//6、element操作
ew1 = network->addElementWise(*bn3->getOutput(0), *bn2->getOutput(0), ElementWiseOperation::kSUM);
//7、
 ITensor* inputTensors[] = {relu1->getOutput(0), relu3->getOutput(0), relu5->getOutput(0), relu6->getOutput(0)};
IConcatenationLayer* cat1 = network->addConcatenation(inputTensors, 4);

使用上面常用层,就可以构建各种分类网络了,例如googlenet、resnet等

三、通过扩展IPluginV2Ext和IPluginCreator类来添加自定义层

3.1、IPluginV2Ext和IPluginCreator是和自定义层相关的两个类,其作用如下:

  • IPluginCreator:是自定义图层的创建者类,用户可以使用该类来获取插件名称,版本和插件字段参数。 它还提供了在网络构建阶段创建插件对象并在推理期间反序列化插件对象的方法。
  • IPluginV2Ext:实现自定义插件的基类,包含版本化和对其它格式和单精度的处理;

3.2、注册和调用
TensorRT提供了通过调用REGISTER_TENSORRT_PLUGIN(pluginCreator)来注册插件的功能,该插件将插件创建器静态注册到插件注册表。 在运行时,可以使用外部函数getPluginRegistry()查询插件注册表。 插件注册表存储指向所有已注册插件创建者的指针,可用于基于插件名称和版本来查找特定的插件创建者。 TensorRT库包含可以加载到您的应用程序中的插件。
补充:
1、可以在官方插件中查看和使用tensorrt已经注册的插件;
2、要在应用程序中使用TensorRT注册的插件,必须加载libnvinfer_plugin.so库并且必须注册所有插件。 这可以通过在应用程序代码中调用initLibNvInferPlugins(void * logger,const char * libNamespace)()来完成。

3.3、自定义插件的格式

class MyPluginCreator : public BaseCreator
{
public:
    MyPluginCreator();
    ~MyPluginCreator() override = default;
    const char* getPluginName() const override;
    const char* getPluginVersion() const override;
    const PluginFieldCollection* getFieldNames() override;
    IPluginV2Ext* createPlugin(const char* name, const PluginFieldCollection* fc) override;
    IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
private:
    std::string mNamespace;
    static PluginFieldCollection mFC;
    static std::vector<PluginField> mPluginAttributes;
};
class MyLayerPlugin: public IPluginV2IOExt
{
        public:
            explicit MyLayerPlugin();
            MyLayerPlugin(const void* data, size_t length);
            ~MyLayerPlugin();
            // 该层返回输出的张量个数
            int getNbOutputs() const override
            {
                return 1;
            }
            Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
            // 做一些初始化的工作,有些工作放到了析构函数中做
            int initialize() override;
            virtual void terminate() override {};
            virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
            virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
            // 返回serialize时写了多少个字节到buffer了
            virtual size_t getSerializationSize() const override;
            // 把一些用得着的数据写入一个buffer
            virtual void serialize(void* buffer) const override;
            bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
                return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
            }
            const char* getPluginType() const override;
            const char* getPluginVersion() const override;
            void destroy() override;
            IPluginV2IOExt* clone() const override;
            void setPluginNamespace(const char* pluginNamespace) override;
            const char* getPluginNamespace() const override;
            DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
            bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
            bool canBroadcastInputAcrossBatch(int inputIndex) const override;
            void attachToContext(
                    cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
            void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
            void detachFromContext() override;
    };

你可能感兴趣的:(工程,神经网络)