原始学习代码来自github.com/wang-xinyu/tensorrtx
tensorRT是用来专门部署前向的框架,可以使用解析的方式将网络部署到该框架下(这个方法就不多说了,自己可以找资料学习),这里重点学习使用Network API 重构网络。
// 创建一个网络生成器
IBuilder* builder = createInferBuilder(gLogger);
//生成一个空的网络
INetworkDefinition* network = builder->createNetwork();
//1、输入层
ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims3{3, INPUT_H, INPUT_W});
//2、卷积层
IConvolutionLayer* conv1 = network->addConvolution(*data, 64, DimsHW{11, 11}, weightMap["features.0.weight"], weightMap["features.0.bias"]);
conv1->setStride(DimsHW{4, 4});
conv1->setPadding(DimsHW{2, 2});
//3、激活层
IActivationLayer* relu1 = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU);
//4、池化层
IPoolingLayer* pool1 = network->addPooling(*relu1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
pool1->setStride(DimsHW{2, 2});
//5、全连接层
IFullyConnectedLayer* fc1 = network->addFullyConnected(*pool3->getOutput(0), 4096, weightMap["classifier.1.weight"], weightMap["classifier.1.bias"]);
//6、element操作
ew1 = network->addElementWise(*bn3->getOutput(0), *bn2->getOutput(0), ElementWiseOperation::kSUM);
//7、
ITensor* inputTensors[] = {relu1->getOutput(0), relu3->getOutput(0), relu5->getOutput(0), relu6->getOutput(0)};
IConcatenationLayer* cat1 = network->addConcatenation(inputTensors, 4);
使用上面常用层,就可以构建各种分类网络了,例如googlenet、resnet等
3.1、IPluginV2Ext和IPluginCreator是和自定义层相关的两个类,其作用如下:
3.2、注册和调用
TensorRT提供了通过调用REGISTER_TENSORRT_PLUGIN(pluginCreator)来注册插件的功能,该插件将插件创建器静态注册到插件注册表。 在运行时,可以使用外部函数getPluginRegistry()查询插件注册表。 插件注册表存储指向所有已注册插件创建者的指针,可用于基于插件名称和版本来查找特定的插件创建者。 TensorRT库包含可以加载到您的应用程序中的插件。
补充:
1、可以在官方插件中查看和使用tensorrt已经注册的插件;
2、要在应用程序中使用TensorRT注册的插件,必须加载libnvinfer_plugin.so库并且必须注册所有插件。 这可以通过在应用程序代码中调用initLibNvInferPlugins(void * logger,const char * libNamespace)()来完成。
3.3、自定义插件的格式
class MyPluginCreator : public BaseCreator
{
public:
MyPluginCreator();
~MyPluginCreator() override = default;
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const PluginFieldCollection* getFieldNames() override;
IPluginV2Ext* createPlugin(const char* name, const PluginFieldCollection* fc) override;
IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
private:
std::string mNamespace;
static PluginFieldCollection mFC;
static std::vector<PluginField> mPluginAttributes;
};
class MyLayerPlugin: public IPluginV2IOExt
{
public:
explicit MyLayerPlugin();
MyLayerPlugin(const void* data, size_t length);
~MyLayerPlugin();
// 该层返回输出的张量个数
int getNbOutputs() const override
{
return 1;
}
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
// 做一些初始化的工作,有些工作放到了析构函数中做
int initialize() override;
virtual void terminate() override {};
virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
// 返回serialize时写了多少个字节到buffer了
virtual size_t getSerializationSize() const override;
// 把一些用得着的数据写入一个buffer
virtual void serialize(void* buffer) const override;
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
}
const char* getPluginType() const override;
const char* getPluginVersion() const override;
void destroy() override;
IPluginV2IOExt* clone() const override;
void setPluginNamespace(const char* pluginNamespace) override;
const char* getPluginNamespace() const override;
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
bool canBroadcastInputAcrossBatch(int inputIndex) const override;
void attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
void detachFromContext() override;
};