编写custom插件需要写两个类,分别如下:
AddPlugin:继承IPluginV2IOExt,插件类,用于编写插件需要实现的功能
AddPluginCreator:继承IPluginCreator,插件Factory类,用于创建插件
class AddPlugin: public nvinfer1::IPluginV2IOExt
class AddPluginCreator : public nvinfer1::IPluginCreator
后续工作:
将插件添加到TensorRT-OSS
将插件添加到onnx-tensorrt
参考链接:TensorRT: nvinfer1::IPluginV2IOExt Class Reference (nvidia.com)
构造函数一般需要实现两个:
第一个用于在创建plugin的过程,此时PluginCreator的createPlugin成员函数会调用
AddPlugin(nvinfer1::Weights valueToAdd)
第二个用于Plugin类的clone成员函数,PluginCreator的deserializePlugin成员函数
AddPlugin(const void *buffer, size_t length)
同时,需要禁用默认构造函数
AddPlugin() = delete
析构函数用于释放该plugin之前开辟的显存空间
~AddPlugin() {}
四个重要的成员函数
getOutputDimensions
TensorRT支持Dynamic-Shape时,batch这一维度必须是explicit的,也就是说,TensorRT处理的维度从以往的三维【3,-1,-1】变成了【1,3,-1,-1】。 根据输入的维度推导出该plugin输出的维度。
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* pInputDim, int nInputDim) override
{
/* get the dimension of an output tensor.
* index: the index of the output tensor.
* pInputDim: the input tensors.
* nInputDim: the number of input tensors. */
return pInputDim[0];
}
supportsFormatCombination
判断pos索引的输入/输出数据是否符合指定的format格式和type数据类型。
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override
{
/*
return true if plugin supports the format and datatype for the input/output indexed by pos.
* inputs are numbers [0, nbInputs-1]
* outputs are numbers [nbInputs, nbInputs+nbOutputs-1]
*
*/
switch(pos) {
case 0:
printf("inOut[0].type = %d, format[0]=%d\n", (int)inOut[0].type, (int)inOut[0].format);
return
((inOut[0].type == nvinfer1::DataType::kFLOAT || inOut[0].type == nvinfer1::DataType::kHALF) && inOut[0].format == nvinfer1::TensorFormat::kLINEAR)
|| (inOut[0].type == nvinfer1::DataType::kINT8 && inOut[0].format == nvinfer1::TensorFormat::kCHW4);
case 1:
printf("inOut[1].type = %d, format[1]=%d\n", (int)inOut[1].type, (int)inOut[1].format);
return inOut[0].format == inOut[1].format && inOut[0].type == inOut[1].type;
}
return false;
}
configurePlugin
判断输入和输出类型,数量是否正确。
virtual void configurePlugin(const nvinfer1::PluginTensorDesc* in, int nbInput, const nvinfer1::PluginTensorDesc* out, int nbOutput) override
{
/*
fields that a plugin might see for an input or output.
* scale is only valid when datatype is DataType::kINT8.
* TensorRT will set the value to -1.0f if it is invalid.
*/
m.dataType = in[0].type;
m.inputDim = in[0].dims;
m.scale = in[0].scale;
printf("configurePlugin type=%d, m.scale=%f\n", (int)out[0].type, m.scale);
}
enqueue
该plugin功能实现的接口,功能实现的cuda或cpu代码放入此。
int enqueue(int nBatch, const void * const *inputs, void **outputs, void* workspace, cudaStream_t stream) override;
四个注册到pluginFactory的信息
set/getPluginNamespace: 为plugin设置namespace名字,如果不设置则默认是"",需要注意的是同一个namespace下的plugin的名字相同会冲突。 getPluginType:获取plugin的name getPluginVersion: 获取plugin的版本
void setPluginNamespace(const char* szNamespace) override {}
const char* getPluginNamespace() const override {return "";}
const char* getPluginType() const override {return "AddPlugin";}
const char* getPluginVersion() const override {return "0";}
获取plugin的信息
getNbOutputs:获取plugin输出的个数,这个根据plugin的功能事先决定
int getNbOutputs() const override
{
return 1;
}
getOutputDataType:获取plugin输出数据的类型是否满足要求
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override
{
return inputTypes[0] == nvinfer1::DataType::kFLOAT ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kINT8;
}
getWorkspaceSize:获取plugin运行占用的显存大小,需要确定这个op需要多大的显存空间去运行,在实际运行的时候就可以直接使用TensorRT开辟好的空间而不是自己去申请显存空间。
size_t getWorkspaceSize(int nMaxBatchSize) const override {return 0;}
initialize
初始化函数,在这个plugin准备开始运行之前执行。
int initialize() override {return 0;}
clone
将plugin对象克隆一份给TensorRT的builder,network和engine。
nvinfer1::IPluginV2IOExt* clone() const override
{
return new AddPlugin(&m, sizeof(m));
}
serialize
将plugin中的参数序列化写入buffer文件中
virtual void serialize(void *buffer) const override {
memcpy(buffer, &m, sizeof(m));
}
getSerializationSize:得到plugin中参数的内存大小,返回序列化时需要写多少字节到buffer中。(第二个析构函数)
virtual size_t getSerializationSize() const override
{
return sizeof(m);
}
两个plugin结束处理函数
terminate:继承父类,无操作 destroy:用于销毁plugin的对象
void terminate() override {}
void destroy() override { delete this; }
四个不重要的函数
bool canBroadcastInputAcrossBatch(int inputIndex) const override {return false;}
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const {return false;}
void attachToContext(cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, nvinfer1::IGpuAllocator* /*allocator*/) {}
void detachFromContext() {}
参考链接:TensorRT: nvinfer1::IPluginCreator Class Reference (nvidia.com)
构造函数和析构函数
构造函数用于初始化需要传入plugin中的权重和参数。
MyCustomPluginCreator::MyCustomPluginCreator()
{
mPluginAttributes.emplace_back(PluginField("in_channel", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(PluginField("weight", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(PluginField("bias", nullptr, PluginFieldType::kFLOAT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
四个plugin的相关信息获取/设定
const char* getPluginName() const override {return "AddPlugin";}
const char* getPluginVersion() const override {return "0";}
void setPluginNamespace(const char* szNamespace) override {}
const char* getPluginNamespace() const override {return "";}
createPlugin
通过PluginFieldCollection将plugin需要的权重和参数,并调用插件类的第一个构造函数创建plugin。
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override {
std::cout << __FUNCTION__ << std::endl;
float valueToAdd = 0;
for (int i = 0; i < fc->nbFields; i++) {
if (!strcmp(fc->fields[i].name, "valueToAdd")) {
valueToAdd = *(float *)fc->fields[i].data;
}
}
return new AddPlugin({nvinfer1::DataType::kFLOAT, &valueToAdd, 1});
}
deserializePlugin
从保存的engine文件中反序列化数据
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override {
return new AddPlugin(serialData, serialLength);
}
getFieldNames
用于一系列PluginFiled对象,传入createPlugin中,创建plugin对象
const nvinfer1::PluginFieldCollection* getFieldNames() override {
std::cout << __FUNCTION__ << std::endl;
return nullptr;
}
当我们只想在某个项目中使用该plugin,可以通过在插件的实现cpp或cu文件中添加如下代码完成plugin的注册。
REGISTER_TENSORRT_PLUGIN(AddPluginCreator);
编写TensorRT-OSS的plugin时,插件类有时继承的类不同,而插件工厂类则继承的是BaseCreator。
在onnx-tensorrt中的builtin_op_importers.cpp文件中,我们采用DEFINE_BUILTIN_OP_IMPORTER去注册op,然后通过parse解析onnx模型,根据注册好的op去一个个解析并构建模型。