本文源于学习TensorRT文档《TensorRT-Developer-Guide》第4章“EXTENDING TENSORRT WITH CUSTOM LAYERS”的理解。
自定义层添加是通过扩展IPluginV2Ext和IPluginCreator类来实现:
对定义好的插件可以通过REGISTER_TENSORRT_PLUGIN(pluginCreator)
进行静态注册,并在使用时通过getPluginRegistry()
查询并使用。官方已经实现的插件有:
// 通过getPluginRegistry获取所有TensorRT插件,creator即IPluginCreator对象
auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion);
const PluginFieldCollection* pluginFC = creator->getFieldNames();
// 填充该层参数信息,pluginData需要先通过PluginField分配堆上空间
PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields);
// 使用层名和插件参数创建新的插件对象,创建在堆上,需要主动释放
IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData);
// 在网络上添加一层,并将该层和插件绑定,layer即IPluginV2Layer对象
auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj);
// TODO:创建最新的网络,并序列化引擎
// 销毁插件对象
pluginObj->destroy()
// TODO:释放TensorRT资源,network、engine、builder
// TODO:释放显存空间,如原网络参数信息pluginData
TensorRT的引擎会在序列化时内部存储IPluginV2插件的属性信息,并在反序列化时通过插件注册表进行查找,并通过IPluginV2::destroy()接口内部销毁。
过去的版本中,用户必须通过nvinfer1::IPluginFactory类在反序列化时创建插件,现在的TensorRT版本可以使用addPluginV2即可。例如:
// 使用Caffe解释器解析网络并添加插件
// 如果使用IPluginExt创建插件,需要搭配nvinfer1::IPluginFactory 和 nvinfer1::IPluginFactory
class FooPlugin : public IPluginExt
{
// TODO:创建插件实现方法
};
class MyPluginFactory :
public nvinfer1::IPluginFactory,
public nvcaffeparser1::IPluginFactoryExt
{
// TODO:创建插件的工厂方法
};
// 如果使用IPluginV2创建并注册插件,则不再需要实现nvinfer1::IPluginFactory,
// 但需要通过nvcaffeparser1::IPluginFactoryV2 和 IPluginCreator来完成注册
class FooPlugin : public IPluginV2
{
// TODO:创建插件实现方法
};
class FooPluginFactory : public nvcaffeparser1::IPluginFactoryV2
{
virtual nvinfer1::IPluginV2* createPlugin(...)
{
// TODO:创建并返回插件对象,如FooPlugin
}
bool isPlugin(const char* name)
{
// TODO:通过网络层的名字检验是否使用该插件
}
}
class FooPluginCreator : public IPluginCreator
{
// TODO:实现所有的插件创建
};
REGISTER_TENSORRT_PLUGIN(FooPluginCreator);
具体的插件创建实例可以查看:
该部分内容基本与创建时介绍的情况雷同,需要注意的是对于Caffe解释器,可以通过setPluginFactoryV2 和 IPluginFactoryV2使用自定义插件,那么在反序列化时创建的插件会按照 IPluginExt::destroy()中定义的内容内部销毁而无需手动调用,用户只需要销毁创建创建过程中的插件对象。
1、获取插件输出数据结构,检验是否可以和相邻层对接:
2、获取插件除了输入输出外,需要占用多大的空间存储数据,在builder中调用并预分配:
3、插件在创建阶段会多次配置、初始化、执行、中止,而运行时只会多次执行,配置、初始化、中止只执行一次,initialize申请的内存需要在terminate时被释放,其它的内存需要在destroy释放,所需要的插件为:
4、通过IPluginV2Ext可以实现输入输出的广播性质,需要实现:
IPluginCreator中用来从插件库中查找并创建插件的方法:
5.x.x版本中没有getOutputDataType、isOutputBroadcastAcrossBatch、canBroadcastInputAcrossBatch,configurePlugin是针对configureWithFormat的升级。在迁移到5.1.x时需要实现这些新特性。
virtual nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const = 0;
virtual bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const = 0;
virtual bool canBroadcastInputAcrossBatch(int inputIndex) const = 0;
virtual void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) = 0;