tensorrt7.0 的文档路径:https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/index.html
https://www.cnblogs.com/shouhuxianjian/p/10532950.html
文章对attention ocr的全流程进行了讲解,内容非常好,尤其是对与网络的介绍。 整体流程为:encoder+decoder ----------encoder采用CNN+biLSTM模型 ------- decoder采用Attention模型
attention ocr最终回归出来的结果是69(69个字符)1728(2472)
其中包含了lstm,但是lstm在tensorrt的7.0以下版本都没有实现,所以最好是转onnx然后做trt的实现。
mPluginAttributes.emplace_back(PluginField("shareLocation", nullptr, PluginFieldType::kINT32, 1));
const PluginField* fields = fc->fields;
mClipBoxes = true;
for (int i = 0; i < fc->nbFields; ++i)
{
const char* attrName = fields[i].name;
if (!strcmp(attrName, "shareLocation"))
{
params.shareLocation = *(static_cast<const bool*>(fields[i].data));
}
}
BatchedNMSPlugin* plugin = new BatchedNMSPlugin(params);
BatchedNMSPlugin::BatchedNMSPlugin(const void* data, size_t length)
{
const char *d = reinterpret_cast<const char*>(data), *a = d;
param = read<NMSParameters>(d);
boxesSize = read<int>(d);
scoresSize = read<int>(d);
numPriors = read<int>(d);
mClipBoxes = read<bool>(d);
ASSERT(d == a + length);
}
b. getNbOutputs需要返回编写的plugin的输出数量,这个需要预先知道并填写;
int BatchedNMSPlugin::getNbOutputs() const
{
return 4;
}
c. getOutputDimensions函数,可以通过该函数获取输入的参数数量以及每个参数的维度信息
inputs[0].d[0]为输入1维度1的信息;
inputs[1].d[0]为输入2维度1的信息;
inputs[0].nbDims可以获取输入一有多少维;
nbInputDims为输入数量;
此处也可以获取一些全局变量,例如输入数据的维度,可以将其作为全局变量,以备enqueue函数使用。
此处需要预先知道输出的维度是多少并返回,如果只有2为可以使用return DimsHW( ,);输出三维可以使用return DimsCHW( ,,);等
d. enqueue函数是trt inference的主要入口,其中调用的函数需要编写cu文件来实现或者直接通过cuda的标准接口实现。
例如:
pluginStatus_t status = nmsInference(stream, batchSize, boxesSize, scoresSize, param.shareLocation,
param.backgroundLabelId, numPriors, param.numClasses, param.topK, param.keepTopK, param.scoreThreshold,
param.iouThreshold, DataType::kFLOAT, locData, DataType::kFLOAT, confData, keepCount, nmsedBoxes, nmsedScores,
nmsedClasses, workspace, param.isNormalized, false, mClipBoxes);
该接口需要在plugin/common/kernal.h里面进行定义,并在plugin/common/kernal/底下编写cu code。
e. getSerializationSize用来说明你在序列化写的数据长度或者是你在反序列化读的长度(是一致的);
f. serialize用于序列化,需要写参数。
void BatchedNMSPlugin::serialize(void* buffer) const
{
char *d = reinterpret_cast<char*>(buffer), *a = d;
write(d, param);
write(d, boxesSize);
write(d, scoresSize);
write(d, numPriors);
write(d, mClipBoxes);
ASSERT(d == a + getSerializationSize());
}
g. configurePlugin和getOutputDimensions的输入数据基本类似,同样可以获取输入参数的维度信息,并设置全局变量;
这个函数相对于getOutputDimensions而言,就是这个函数是必须要实现的且全局变量必须赋予值。
里面也可以对输入数据进行维度的报错。
h.clone函数需要调用plugin的构造并传递全局参数。
以上几个函数是必须要实现的,其他的函数参考示例就行。