OpenCV----YOLOv7目标检测模型推理

题目要求:上一篇博客《OpenCV----YOLOv5目标检测模型推理》介绍了YOLOv5目标检测模型;本次在此基础上兼容YOLOv7,构建一个基于面向对象设计的目标检测模型框架。
yolov7论文地址:YOLOv7 paper
yolov7项目地址:YOLOv7 github
YOLOv7在目标检测领域的速度与精度上都达到了新的高度,在模型结构、重参数化、模型缩放等方面都做了细致的设计。本次基于最新的YOLOv7模型进行OpenCV部署,整体框架与前面的YOLOv5、YOLACT相同,只是新增了YOLOv7类及其方法的实现。

  • 结果展示

    注:图为yolov7 640x640模型的测试结果。可以看到,YOLOv7的检测更细致准确:右上角窗台上的自行车也能很好地检测到,人物的领带也被检测到,人物框也比较准确。
  • 代码示例
#include <algorithm>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <opencv2/dnn.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>

#include "config.cpp"

using namespace cv;
using namespace dnn;
using namespace std;

/// YOLOv7 object detector running on an OpenCV DNN network.
///
/// The network input size is not passed in directly: the constructors parse it
/// from the model file name (e.g. "yolov7_640x640.onnx" -> 640 x 640).
class yolov7
{
public:
    /// @param confThreshold  minimum score a proposal needs to be kept
    /// @param nmsThreshold   IoU threshold handed to non-maximum suppression
    /// @param model_path     path of the model file loaded with readNet
    /// @param keep_top_k     accepted for interface compatibility; this class stores no such value
    yolov7(float confThreshold, float nmsThreshold, string model_path = "model/yolov7_640x640.onnx", const int keep_top_k = 200);

    /// Builds the detector from a shared net_config (thresholds and model path).
    yolov7(net_config& config);

    /// Runs detection on @p frame and draws the kept boxes onto it in place.
    void detect(Mat& frame);

private:
    /// Draws one box plus its "class:score" label onto @p frame.
    void drawPred(float conf, int left, int top, int right, int bottom, Mat& frame, int classid);

    // Network input size, parsed from the model file name.
    int inpWidth;
    int inpHeight;

    // Score and NMS thresholds used by detect().
    float confThreshold;
    float nmsThreshold;

    // The loaded network.
    Net net;
};

yolov7::yolov7(float confThreshold, float nmsThreshold, string model_path, const int keep_top_k)
{
    this->confThreshold = confThreshold;
    this->nmsThreshold = nmsThreshold;

    this->net = readNet(model_path);

    size_t pos = model_path.find("_");
    int len = model_path.length() - 6 - pos;
    string hxw = model_path.substr(pos + 1, len);
    pos = hxw.find("x");
    string h = hxw.substr(0, pos);
    len = hxw.length() - pos;
    string w = hxw.substr(pos + 1, len);
    this->inpHeight = stoi(h);
    this->inpWidth = stoi(w);
}

yolov7::yolov7(net_config& config)
{
    this->confThreshold = config.confThreshold;
    this->nmsThreshold = config.nmsThreshold;

    this->net = readNet(config.model_path);

    size_t pos = config.model_path.find("_");
    int len = config.model_path.length() - 6 - pos;
    string hxw = config.model_path.substr(pos + 1, len);
    pos = hxw.find("x");
    string h = hxw.substr(0, pos);
    len = hxw.length() - pos;
    string w = hxw.substr(pos + 1, len);
    this->inpHeight = stoi(h);
    this->inpWidth = stoi(w);
}

void yolov7::drawPred(float conf, int left, int top, int right, int bottom, Mat& frame, int classid)   // Draw the predicted bounding box
{
    //Draw a rectangle displaying the bounding box
    rectangle(frame, Point(left, top), Point(right, bottom), Scalar(0, 0, 255), 2);

    //Get the label for the class name and its confidence
    string label = format("%.2f", conf);
    label = string(class_names[classid+1]) + ":" + label;

    //Display the label at the top of the bounding box
    int baseLine;
    Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
    top = max(top, labelSize.height);
    //rectangle(frame, Point(left, top - int(1.5 * labelSize.height)), Point(left + int(1.5 * labelSize.width), top + baseLine), Scalar(0, 255, 0), FILLED);
    putText(frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 255, 0), 1);
    // **** video detection **** //
    // static const string kWinName = "YOLOv7 Object Detection in OpenCV";
	// namedWindow(kWinName, WINDOW_NORMAL);
	// imshow(kWinName, frame);
	// waitKey(10);
}

// Run the network on frame and draw the detections that survive NMS onto it (in place).
//
// Each row of the first network output is read as
//   [cx, cy, w, h, box_score, class_score_0, class_score_1, ...]
// with coordinates in network-input pixels. They are rescaled to the frame size.
// A 3-D output (1 x N x nout) is first flattened to N rows of nout values.
void yolov7::detect(Mat& frame)
{
    Mat blob = blobFromImage(frame, 1 / 255.0, Size(this->inpWidth, this->inpHeight), Scalar(0, 0, 0), true, false);
    this->net.setInput(blob);
    vector<Mat> outs;
    this->net.forward(outs, this->net.getUnconnectedOutLayersNames());

    // The rows are read through float pointers below, so the output must be CV_32F.
    CV_Assert(!outs.empty() && outs[0].type() == CV_32F);

    Mat pred = outs[0];
    int num_proposal = pred.size[0];
    int nout = pred.size[1];
    if (pred.dims > 2)
    {
        num_proposal = pred.size[1];
        nout = pred.size[2];
        pred = pred.reshape(0, num_proposal);
    }

    // Generate proposals.
    vector<float> confidences;
    vector<Rect> boxes;
    vector<int> classIds;
    const float ratioh = (float)frame.rows / this->inpHeight;
    const float ratiow = (float)frame.cols / this->inpWidth;
    for (int row = 0; row < num_proposal; ++row)
    {
        const float* pdata = pred.ptr<float>(row);  // cx, cy, w, h, box_score, class scores...
        const float box_score = pdata[4];
        if (box_score > this->confThreshold)
        {
            // Best class score for this proposal, weighted by box_score.
            Mat scores = pred.row(row).colRange(5, nout);
            Point classIdPoint;
            double max_class_score;
            minMaxLoc(scores, 0, &max_class_score, 0, &classIdPoint);
            max_class_score *= box_score;
            if (max_class_score > this->confThreshold)
            {
                const float cx = pdata[0] * ratiow;
                const float cy = pdata[1] * ratioh;
                const float w = pdata[2] * ratiow;
                const float h = pdata[3] * ratioh;

                // Convert centre/size to a top-left anchored rectangle.
                const int left = int(cx - 0.5 * w);
                const int top = int(cy - 0.5 * h);

                confidences.push_back((float)max_class_score);
                boxes.push_back(Rect(left, top, (int)(w), (int)(h)));
                classIds.push_back(classIdPoint.x);
            }
        }
    }

    // Non-maximum suppression to drop redundant overlapping boxes with lower
    // confidence. NOTE: all boxes are passed together, so this NMS is
    // class-agnostic (boxes of different classes can suppress each other).
    vector<int> indices;
    NMSBoxes(boxes, confidences, this->confThreshold, this->nmsThreshold, indices);
    for (const int idx : indices)
    {
        const Rect& box = boxes[idx];
        this->drawPred(confidences[idx], box.x, box.y,
            box.x + box.width, box.y + box.height, frame, classIds[idx]);
    }
}

// **** unit test ****//
// int main()
// {
//     yolov7 net(0.3, 0.5, "model/yolov7_640x640.onnx");
//     string imgpath = "inference/bus.jpg";
//     Mat srcimg = imread(imgpath);
//     net.detect(srcimg);

//     static const string kWinName = "Deep learning object detection in OpenCV";
//     namedWindow(kWinName, WINDOW_NORMAL);
//     imshow(kWinName, srcimg);
//     waitKey(0);
//     destroyAllWindows();
// }

你可能感兴趣的:(目标检测,opencv,opencv,目标检测,计算机视觉,yolov7)