目录
python tensorrt:激活函数:hard_sigmoid
Config
c++ diou_nms
这个感觉靠谱,还没试:
tensorrtx/yolov5 at master · wang-xinyu/tensorrtx · GitHub
这个不错,c++调通的版本:
yolov5转tensorrt c++_jacke121的专栏-CSDN博客
yolov5 tensorrt int8训练开源项目:
https://github.com/maggiez0138/yolov5_quant_sample
https://github.com/TrojanXu/yolov5-tensorrt
https://github.com/wang-xinyu/tensorrtx
The Pytorch implementation is ultralytics/yolov5.
Currently, we support yolov5 v1.0(yolov5s only), v2.0, v3.0 and v3.1.
通过 yolov5.cpp 中的 NET 宏选择模型。
v2 也是 leakyrelu。
2.0:这个是匹配3.0的版本,用的leakyrelu,可以检测,v3.0自己训练的精度比较低
GitHub - BaofengZan/yolov5_2.0-TensorRt: U版yolov5 2.0的tensorrt加速
https://github.com/AIpakchoi/yolov5_tensorrt/blob/110762eea4a7a53a91bbce35f94239136db157d4/yolov5l/common.hpp
GitHub - baituhuangyu/yolov5-tensorrt: yolov5 tensorrt inference
https://github.com/Thinker-or-Dreamer/UAV-And-RobotArm/tree/master/yolov5
HardSwishLayer_TRT
https://github.com/hlld/tensorrt-yolov5
linux的:
2020.10.23 17天以前更新的,激活函数:kHARD_SIGMOID
https://github.com/wang-xinyu/tensorrtx/tree/master/yolov5
yololayer.h
修改自己类别个数:
// Number of object classes the detector outputs; change this to match your own dataset.
static constexpr int CLASS_NUM{1};
下面是阈值参数,nms阈值参数:
#define USE_FP16 // build the engine in FP16; comment this line out to use FP32 instead
#define DEVICE 0 // GPU id
#define NMS_THRESH 0.4 // NMS suppression threshold
#define CONF_THRESH 0.3 // confidence threshold for detections
// Usage help: "-s" serializes the model to a plan file; "-e" deserializes the plan
// file and runs detection either on a camera ("-c 0") or on a samples path ("-d").
std::cerr << "yolov5_rt.exe -s s // serialize model to plan file" << std::endl;
std::cerr << "yolov5_rt.exe -e s -c 0 // detect cam" << std::endl;
std::cerr << "yolov5_rt.exe -e s -d samples // deserialize plan file and run inference" << std::endl;
上面3行是调用示例,分两步:先用 -s 序列化模型生成 plan 文件,再用 -e 反序列化 plan 文件并执行推理(-c 0 表示摄像头,-d 指定 samples 路径)
问题:原版网络检测出来的框没问题
自己训练的,tensorrt检测的与pytorch检测出来的有偏差,原因还未找到。
可能是anchors的原因,但是没找到证据。
/// Axis-aligned box given by corners (x1, y1) and (x2, y2).
struct BBox
{
    float x1, y1, x2, y2;
};

/// One detection: its box, label / class id and confidence.
struct BBoxInfo
{
    BBox box;
    int label;
    int classId; // For coco benchmarking
    float prob;
};

/// Distance-IoU non-maximum suppression (DIoU-NMS, https://arxiv.org/pdf/1911.08287.pdf).
///
/// Detections are visited in descending order of `prob` (stable, so ties keep input
/// order). A detection is kept unless, for some already-kept detection,
/// IoU - R > nmsThresh, where R is the squared distance between the two box centres
/// divided by the squared diagonal of the smallest box enclosing both.
/// `label` / `classId` are not compared, so suppression is class-agnostic.
///
/// @param nmsThresh  suppression threshold applied to (IoU - R).
/// @param binfo      candidate detections (taken by value and sorted in place).
/// @return           the kept detections, highest `prob` first.
std::vector<BBoxInfo> diou_nms(const float nmsThresh, std::vector<BBoxInfo> binfo)
{
    // Length of the overlap of [x1min, x1max] and [x2min, x2max]; 0 when disjoint.
    auto overlap1D = [](float x1min, float x1max, float x2min, float x2max) -> float
    {
        if (x1min > x2min)
        {
            std::swap(x1min, x2min);
            std::swap(x1max, x2max);
        }
        return x1max < x2min ? 0.f : std::min(x1max, x2max) - x2min;
    };

    // Intersection-over-union; defined as 0 when the union area is 0.
    auto computeIoU = [&overlap1D](const BBox& bbox1, const BBox& bbox2) -> float
    {
        const float overlapX = overlap1D(bbox1.x1, bbox1.x2, bbox2.x1, bbox2.x2);
        const float overlapY = overlap1D(bbox1.y1, bbox1.y2, bbox2.y1, bbox2.y2);
        const float area1 = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1);
        const float area2 = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1);
        const float overlap2D = overlapX * overlapY;
        const float u = area1 + area2 - overlap2D;
        return u == 0.f ? 0.f : overlap2D / u;
    };

    // DIoU penalty R = (centre distance)^2 / (enclosing-box diagonal)^2.
    auto R = [](const BBox& bbox1, const BBox& bbox2) -> float
    {
        const float center1_x = (bbox1.x1 + bbox1.x2) / 2.f;
        const float center1_y = (bbox1.y1 + bbox1.y2) / 2.f;
        const float center2_x = (bbox2.x1 + bbox2.x2) / 2.f;
        const float center2_y = (bbox2.y1 + bbox2.y2) / 2.f;
        const float d_center = (center1_x - center2_x) * (center1_x - center2_x)
                             + (center1_y - center2_y) * (center1_y - center2_y);
        // Smallest box enclosing both inputs.
        const float box_x1 = std::min({ bbox1.x1, bbox1.x2, bbox2.x1, bbox2.x2 });
        const float box_y1 = std::min({ bbox1.y1, bbox1.y2, bbox2.y1, bbox2.y2 });
        const float box_x2 = std::max({ bbox1.x1, bbox1.x2, bbox2.x1, bbox2.x2 });
        const float box_y2 = std::max({ bbox1.y1, bbox1.y2, bbox2.y1, bbox2.y2 });
        const float d_diagonal = (box_x1 - box_x2) * (box_x1 - box_x2)
                               + (box_y1 - box_y2) * (box_y1 - box_y2);
        // Both boxes collapse to the same point (so d_center is 0 too). Without this
        // guard 0/0 yields NaN, which made the keep-test silently suppress the box.
        return d_diagonal == 0.f ? 0.f : d_center / d_diagonal;
    };

    std::stable_sort(binfo.begin(), binfo.end(),
        [](const BBoxInfo& b1, const BBoxInfo& b2) { return b1.prob > b2.prob; });

    std::vector<BBoxInfo> out;
    for (const auto& cand : binfo)
    {
        const bool keep = std::none_of(out.begin(), out.end(),
            [&](const BBoxInfo& kept) {
                return computeIoU(cand.box, kept.box) - R(cand.box, kept.box) > nmsThresh;
            });
        if (keep) out.push_back(cand);
    }
    return out;
}
设置阈值:
文件路径:class_yolo_detector.hpp
/// Fill `_yolo_info` and `_infer_param` from the user configuration `_config`.
///
/// The weights path must contain ".weights". Everything before the LAST occurrence
/// of that extension becomes `data_path`, and the calibration-table path is derived
/// from it as "<data_path>-calibration.table".
void parse_config()
{
    _yolo_info.networkType = _vec_net_type[_config.net_type];
    _yolo_info.configFilePath = _config.file_model_cfg;
    _yolo_info.wtsFilePath = _config.file_model_weights;
    _yolo_info.precision = _vec_precision[_config.inference_precison];
    _yolo_info.deviceType = "kGPU";

    // rfind: a directory name containing ".weights" must not truncate data_path
    // (and therefore the calibration-table path) too early.
    auto npos = _yolo_info.wtsFilePath.rfind(".weights");
    // NOTE(review): assert is compiled out under NDEBUG; in that case a path without
    // ".weights" makes substr(0, npos) keep the whole path.
    assert(npos != std::string::npos
        && "wts file not recognised. File needs to be of '.weights' format");
    _yolo_info.data_path = _yolo_info.wtsFilePath.substr(0, npos);
    _yolo_info.calibrationTablePath = _yolo_info.data_path + "-calibration.table";
    _yolo_info.inputBlobName = "data";

    _infer_param.printPerfInfo = false;
    _infer_param.printPredictionInfo = false;
    _infer_param.calibImages = _config.calibration_image_list_file_txt;
    _infer_param.calibImagesPath = "";
    _infer_param.probThresh = _config.detect_thresh;
    // NMS threshold is hard-coded here rather than read from _config.
    _infer_param.nmsThresh = 0.3;
}