TensorRT学习(三)通过自定义层扩展TensorRT

  本文源于学习TensorRT文档《TensorRT-Developer-Guide》第4章“EXTENDING TENSORRT WITH CUSTOM LAYERS”的理解。

通过C++API添加自定义层

  自定义层添加是通过扩展IPluginV2Ext和IPluginCreator类来实现:

  1. IPluginV2Ext:IPluginV2的升级版,实现自定义插件的基类,包含版本化和对其它格式和单精度的处理;
  2. IPluginCreator:自定义层的创建类,可以通过它获取插件的名称、版本信息、参数等,也提供网络创建阶段创建插件的方法,并在推理阶段反序列化它。

  对定义好的插件可以通过REGISTER_TENSORRT_PLUGIN(pluginCreator)进行静态注册,并在使用时通过getPluginRegistry()查询并使用。官方已经实现的插件有:

  • RPROI_TRT
  • Normalize_TRT
  • PriorBox_TRT
  • GridAnchor_TRT
  • NMS_TRT
  • LReLU_TRT
  • Reorg_TRT
  • Region_TRT
  • Clip_TRT
// 通过getPluginRegistry获取所有TensorRT插件,creator即IPluginCreator对象
auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion);
const PluginFieldCollection* pluginFC = creator->getFieldNames();

// 填充该层参数信息,pluginData需要先通过PluginField分配堆上空间
PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields);

// 使用层名和插件参数创建新的插件对象,创建在堆上,需要主动释放
IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData);

// 在网络上添加一层,并将该层和插件绑定,layer即IPluginV2Layer对象
auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj);

// TODO:创建最新的网络,并序列化引擎

// 销毁插件对象
pluginObj->destroy() 

// TODO:释放TensorRT资源,network、engine、builder
// TODO:释放显存空间,如原网络参数信息pluginData 

  TensorRT的引擎会在序列化时内部存储IPluginV2插件的属性信息,并在反序列化时通过插件注册表进行查找,并通过IPluginV2::destroy()接口内部销毁。
  过去的版本中,用户必须通过nvinfer1::IPluginFactory类在反序列化时创建插件,现在的TensorRT版本可以使用addPluginV2即可。例如:

// 使用Caffe解释器解析网络并添加插件
// 如果使用IPluginExt创建插件,需要搭配nvinfer1::IPluginFactory 和 nvinfer1::IPluginFactory
class FooPlugin : public IPluginExt
{
	// TODO:创建插件实现方法
};
class MyPluginFactory : 
public nvinfer1::IPluginFactory, 
public nvcaffeparser1::IPluginFactoryExt
{
	// TODO:创建插件的工厂方法
};

// 如果使用IPluginV2创建并注册插件,则不再需要实现nvinfer1::IPluginFactory,
// 但需要通过nvcaffeparser1::IPluginFactoryV2 和 IPluginCreator来完成注册
class FooPlugin : public IPluginV2
{
	// TODO:创建插件实现方法
};
class FooPluginFactory : public nvcaffeparser1::IPluginFactoryV2
{
	virtual nvinfer1::IPluginV2* createPlugin(...)
	{
		// TODO:创建并返回插件对象,如FooPlugin
	}
	bool isPlugin(const char* name)
	{
		// TODO:通过网络层的名字检验是否使用该插件
	}
}
class FooPluginCreator : public IPluginCreator
{
	// TODO:实现所有的插件创建
};
REGISTER_TENSORRT_PLUGIN(FooPluginCreator);

  具体的插件创建实例可以查看:

  • samplePlugin:自定义Caffe网络插件方法;
  • sampleFasterRCNN:通过TensorRT注册Caffe网络插件;
  • sampleUffSSD:对UFF(针对TensorFlow)添加插件。

使用自定义插件

  该部分内容基本与创建时介绍的情况雷同,需要注意的是对于Caffe解释器,可以通过setPluginFactoryV2 和 IPluginFactoryV2使用自定义插件,那么在反序列化时创建的插件会按照 IPluginExt::destroy()中定义的内容内部销毁而无需手动调用,用户只需要销毁创建创建过程中的插件对象。

API描述

IPluginV2的API

  1、获取插件输出数据结构,检验是否可以和相邻层对接:

  • getNbOutputs:验证输出张量数目;
  • getOutputDimensions:验证输入维度,获取输出维度;
  • supportsFormat:设置插件支持的数据类型,如何种处理精度;
  • getOutputDataType:插件输出数据的类型(NCHW、NC/2HW2 、NHWC8等,见PluginFormatType)。

  2、获取插件除了输入输出外,需要占用多大的空间存储数据,在builder中调用并预分配:

  • getWorkspaceSize

  3、插件在创建阶段会多次配置、初始化、执行、中止,而运行时只会多次执行,配置、初始化、中止只执行一次,initialize申请的内存需要在terminate时被释放,其它的内存需要在destroy释放,所需要的插件为:

  • configurePlugin:配置输入输出属性(数量、维度、类型、广播、格式选择、最大BatchSize),插件会选择最合适的算法和数据结构;
  • initialize:在插件配置和推理引擎创建之后使用,根据设置的数据结构配置并准备执行;
  • enqueue:插件实际处理过程,需输入运行BatchSize、输入指针、输出指针、缓存空间指针、CUDA流;
  • terminate:在引擎的上下文被释放时释放插件的所有资源;
  • clone:在需要一个独立插件时(新的builder、network、engine被创建)使用;
  • destroy:在builder、network、engine销毁时调用,释放对应的插件资源;
  • set/getPluginNamespace:设置或获取插件的命名空间,默认为""(空)。

  4、通过IPluginV2Ext可以实现输入输出的广播性质,需要实现:

  • canBroadcastInputAcrossBatch:判断输入张量是否可以在批中进行广播,能广播则返回true,TensorRT不会复制输入并使用同一输入副本;不能广播返回false,TensorRT会复制输入张量;
  • isOutputBroadcastAcrossBatch:指定索引的输出是否被广播。

IPluginCreator的API

  IPluginCreator中用来从插件库中查找并创建插件的方法:

  • getPluginName:获取插件的名字,并和getPluginType配合使用;
  • getPluginVersion:返回插件版本,TensorRT内部插件默认为1;
  • getFieldNames:返回PluginFieldCollection结构数据,包含添加插件的参数名和类型;
  • createPlugin:通过给定的PluginFieldCollection结构参数创建插件,需填充实际所需参数;
  • deserializePlugin:在TensorRT引擎根据插件名和版本信息内部调用,返回用于推理的插件对象;
  • set/getPluginNamespace:creator所在的插件库命名空间,默认为""(空)。

从5.x.x迁移到5.1.x

  5.x.x版本中没有getOutputDataType、isOutputBroadcastAcrossBatch、canBroadcastInputAcrossBatch,configurePlugin是针对configureWithFormat的升级。在迁移到5.1.x时需要实现这些新特性。

virtual nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const = 0;
virtual bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const = 0;
virtual bool canBroadcastInputAcrossBatch(int inputIndex) const = 0;
virtual void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) = 0;

  

  

  

  

你可能感兴趣的:(深度学习,开源架构)