最近研究 TensorRT的自定义层,尝试的使用了自定义的FC层FC层和Upsample层Upsample层之后,重新回去看开发者手册,在此记录。
自己的理解,TensorRT的自定义层机制是有两个方法的,一种基于基类IPlugin,另一种是基于基类IPluginV2,从字面意思上来看IPluginV2就是最新版本。
IPlugin类的方法,是通过自己编写IPlugin的派生类IPluginExt
和nvinfer1::IPluginFactory类以及nvcaffeparser1::IPluginFactoryExt类
来实现自定义层的编辑和使用的。具体的实例可以看上面的链接。就是通过IPluginFactory类将IPluginExt类定义的Plugin实例化,然后在TensorRT中使用。在Developer Guide中对应的是如下这段代码:
//The following sample code adds a new plugin called FooPlugin:
class FooPlugin : public IPluginExt
{
...implement all class methods for your plugin
};
class MyPluginFactory : public nvinfer1::IPluginFactory, public nvcaffeparser1::IPluginFactoryExt
{
...implement all factory methods for your plugin
};
.
IPluginV2类的方法,是关于自定义层的另一种方法,这种方法是调用已经被注册的plugin的方法,这种方法不依赖nvinfer1::IPluginFactory
去实例化,而是使用nvcaffeparser1::IPluginFactoryV2
和 IPluginCreato
r代替,这种方法对应Developer Guide中的如下代码:
class FooPlugin : public IPluginV2
{
...implement all class methods for your plugin
};
class FooPluginFactory : public nvcaffeparser1::IPluginFactoryV2
{
virtual nvinfer1::IPluginV2* createPlugin(...)
{
...create and return plugin object of type FooPlugin
}
bool isPlugin(const char* name)
{
...check if layer name corresponds to plugin
}
}
class FooPluginCreator : public IPluginCreator
{
...implement all creator methods here
};
REGISTER_TENSORRT_PLUGIN(FooPluginCreator);
以前的博客已经用过了很多次基于IPlugin类的自定义层构造,我称之为V1用法,现在主要介绍两者的对比以及IPluginV2类的自定义层用法,称之为V2用法。
class IPlugin
{
public:
virtual int getNbOutputs() const = 0;
virtual Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) = 0;
virtual void configure(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, int maxBatchSize) = 0;
virtual int initialize() = 0;
virtual size_t getWorkspaceSize(int maxBatchSize) const = 0;
virtual int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) = 0;
virtual size_t getSerializationSize() = 0;
virtual void serialize(void* buffer) = 0;
virtual ~IPlugin() {}
};
class IPluginExt
{
public:
virtual int getTensorRTVersion() const
virtual bool supportsFormat(DataType type, PluginFormat format) const = 0;
virtual void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) = 0;
virtual ~IPluginExt() {}
}
class IPluginV2
{
public:
virtual int getTensorRTVersion() const { return NV_TENSORRT_VERSION;}
virtual const char* getPluginType() const = 0;
virtual const char* getPluginVersion() const = 0;
virtual int getNbOutputs() const = 0;
virtual Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) = 0;
virtual bool supportsFormat(DataType type, PluginFormat format) const = 0;
virtual void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) = 0;
virtual int initialize() = 0;
virtual void terminate() = 0;
virtual size_t getWorkspaceSize(int maxBatchSize) const = 0;
virtual int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) = 0;
virtual size_t getSerializationSize() const = 0;
virtual void serialize(void* buffer) const = 0;
virtual void destroy() = 0;
virtual IPluginV2* clone() const = 0;
virtual void setPluginNamespace(const char* pluginNamespace) = 0;
virtual const char* getPluginNamespace() const = 0;
protected:
virtual ~IPluginV2() {}
};
.
从类声明的对比来看,IPlugin类比IPluginV2类少了很多方法,但是IPlugin方法最后new出来的Plugin是继承于IPluginExt,IPluginExt才又继承与IPlugin,V1方法在IPluginExt类中声明了getTensorRTVersion(),supportsFormat(), configureWithFormat(),configure()
等方法。
所以,比较起来IPluginV2多了getPluginVersion(), getPluginType(), setPluginNamespace(), getPluginNamespace()
和IPluginV2* clone()
方法。
getPluginType()
是用来匹配 plugin creator 返回的plugin name的方法。
getPluginVersion()
是用来匹配 plugin creator 返回的plugin version的方法。比如在Developer Guide中,提到的这些RPROI_TRT, Normalize_TRT, PriorBox_TRT等已经被tensorRT封装好的Plugin都是属于version1的。
setPluginNamespace()
是用来设置这个plugin对象属于哪个namesapce的方法,在相同plugin library中的所有plugin对象都应该在同一个namespace中。
getPluginNamespace()
是用来返回该plugin对象所属namespace的。
IPluginV2* clone()
方法,每次创建包含此plugin的新builder,network或engine时,都会调用此方法。 它应该返回一个带有正确参数的新plugin对象。
其余的方法都基本相同,说明其实IPluginV2类和IPlugin类的定义方法相似。
但是,其实除了上面列出的和IPlugin类类似的方法之外,IPluginV2还有一个派生类IPluginExtV2
,通过支持不同的output数据类型和broadcast across batch扩展了IPluginV2类的功能。
class IPluginV2Ext : public IPluginV2
{
public:
virtual nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const = 0;
virtual bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const = 0;
virtual bool canBroadcastInputAcrossBatch(int inputIndex) const = 0;
virtual void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims,
int nbOutputs, const DataType* inputTypes, const DataType* outputTypes,
const bool* inputIsBroadcast, const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize)=0;
virtual ~IPluginV2Ext() {}
virtual void attachToContext(cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) {}
virtual void detachFromContext() {}
virtual IPluginV2Ext* clone() const _TENSORRT_OVERRIDE = 0;
protected:
int getTensorRTVersion() const _TENSORRT_OVERRIDE
{
return (0x01000000 | (NV_TENSORRT_VERSION & 0xFFFFFF));
}
void configureWithFormat(const Dims* /*inputDims*/, int /*nbInputs*/, const Dims* /*outputDims*/,
int /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, int /*maxBatchSize*/) _TENSORRT_OVERRIDE _TENSORRT_FINAL {}
};
.
这里挖个坑,这些方法的具体代码实践以后再试试看。
除了类中方法和派生类的不同,IPluginV2的自定义层的创建,还利用到了另一个类:IPluginCreator。
这个类是用来在Plugin Registry中查找和创建相应的plugin。
class IPluginCreator
{
public:
virtual int getTensorRTVersion() const { return NV_TENSORRT_VERSION; }
virtual const char* getPluginName() const = 0;
virtual const char* getPluginVersion() const = 0;
virtual const PluginFieldCollection* getFieldNames() = 0;
virtual IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) = 0;
virtual IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) = 0;
virtual void setPluginNamespace(const char* pluginNamespace) = 0;
virtual const char* getPluginNamespace() const = 0;
virtual ~IPluginCreator() {}
};
.
getPluginName()方法用来返回plugin name 和匹配IPluginExt::getPluginType的返回值
getPluginVersion()方法返回plugin version,默认是1
getFieldNames()方法返回PluginFieldCollection结构,这个结构会依据field name去填充相关字段,填充PluginFieldType等。
createPlugin()方法就是用PluginFieldCollection来创建Plugin,对应的数据会被填充。
deserializePlugin()方法返回用于推理的Plugin对象,TensorRT根据plugin名称和版本在内部调用这个对象。
set/getPluginNamespace()方法用于设置或者获取这个Creator实例所属的namespace
这个部分准备好好研究一下,所以单写一个博客。
我的想法是分为
(1)IPluginV2方法使用tensorRT已有的自定义层
(2)IPluginV2方法使用自己定义的自定义层
具体请看这里IPluginV2