如果我们在使用TensorRT时有一些操作并不支持,我们可以自行编写将其作为TensorRT的插件层,从而使得这些不能支持的操作能在TensorRT中使用。
我们以上采样层为例,进行编写:
首先我们要先定义一个继承自TensorRT插件基类的Upsample类:
class Upsample : public IPluginExt
然后我们要实现该类的一些必要方法,首先是2个构造函数,一个是传参数构建,另一个是从序列化后的比特流构建:
Upsample(int scale = 2) : mScale(scale){
assert(mScale > 0);
}
//定义上采样倍数
Upsample(const void *data, size_t length)
{
const char *d = reinterpret_cast < const char *>(data), *a = d;
mScale = read<int>(d);
mDtype = read<DimsCHW>(d);
mCHW = read<DimsCHW>(d);
assert(mScale > 0);
assert(d == a + length);
}
~Upsample()
{
}
一些定义层输出信息的方法:
//模型的输出个数
int getNbOutputs() const override{
return 1;
}
//获取模型输出的形状
Dims getOutputDimensions(int index, const Dims *inputs, int nbInputDims)
override{
assert(nbInputDims == 1);
assert(inputs[0].nbDims == 3);
return DimsCHW(inputs[0].d[0], inputs[0].d[1] * mScale, inputs[0].d[2] * mScale);
}
根据输入的形状个数以及采用的数据类型检查合法性以及配置层参数的方法:
bool supportsFormat(DataType type, PluginFormat format) const override {
return (type == DataType::kFLOAT || type == DataType::kHALF || type == DataType::kINT8)
&& format == PluginFormat::kNCHW;
}
//检查层是否支持当前的数据类型和格式
void configureWithFormat(const Dims *inputDims, int nbInputs, const Dims *outputDims, int nbOutputs,
DataType type, PluginFormat format, int maxBatchSize) override
{
mDtype = type;
mCHW.c() = inputDims[0].d[0];
mCHW.h() = inputDims[0].d[1];
mCHW.w() = inputDims[0].d[2];
}
//配置层的参数
层的序列化方法:
size_t getSerializationSize() override {
return sizeof(mScale) + sizeof(mDtype) + sizeof(mCHW);
}
//输出序列化层所需的长度
void serialize(void *buffer) override {
char *d = reinterpret_cast<char *>(buffer), *a = d;
write(d, mScale);
write(d, mDtype);
write(d, mCHW);
assert(d == a + getSerializationSize());
}
//将层参数序列化为比特流
层的运算方法:
size_t getWorkspaceSize(int maxBatchSize) const override {
return 0;
}
//层运算需要的临时工作空间大小
int enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace,
cudaStream_t stream) override;
//层执行计算的具体操作
在enqueue中我们调用编写好的cuda kenerl来进行Upsample的计算
完成了Upsample类的定义,我们就可以直接在网络中添加我们编写的插件了,通过如下语句我们就定义一个上采样2倍的上采样层。addPluginExt的第一个输入是ITensor**类别,这是为了支持多输出的情况,第二个参数就是输入个数,第三个参数就是需要创建的插件类对象。
Upsample up(2);
auto upsamplelayer=network->addPluginExt(inputtensot,1,up)