


#include "BYTETracker.h"
#include "scrfd.h"
#include "mtcnn.h"
/**
 * @brief Least-squares fit of a 2-D similarity transform (rotation, uniform
 *        scale, translation) that maps the points of @p x0 onto @p dst.
 *
 * Every point pair contributes two rows to the linear system A * r = B:
 *
 *     [ x  -y  1  0 ] [a]   [x']
 *     [ y   x  0  1 ] [b] = [y']
 *                     [c]
 *                     [d]
 *
 * which is solved through the normal equations r = (A^T A)^-1 * A^T * B.
 *
 * @param x0  Source points.
 * @param dst Destination points. Only the first min(x0.size(), dst.size())
 *            pairs are used.
 * @return 2x3 CV_32FC1 matrix [a -b c; b a d]; all zeros when no pair is given.
 */
cv::Mat getsrc_roi(std::vector<cv::Point2f> x0, std::vector<cv::Point2f> dst)
{
    // Never index x0 past its end when the two inputs differ in length.
    const int size = static_cast<int>((std::min)(x0.size(), dst.size()));

    cv::Mat roi = cv::Mat::zeros(2, 3, CV_32FC1);
    if (size == 0)
    {
        return roi;
    }

    // Both matrices start zeroed, so only the non-zero coefficients are written.
    cv::Mat A = cv::Mat::zeros(size * 2, 4, CV_32FC1);
    cv::Mat B = cv::Mat::zeros(size * 2, 1, CV_32FC1);

    for (int i = 0; i < size; i++)
    {
        const int rowX = i << 1;   // equation for the x coordinate of pair i
        const int rowY = rowX | 1; // equation for the y coordinate of pair i

        A.at<float>(rowX, 0) = x0[i].x;
        A.at<float>(rowX, 1) = -x0[i].y;
        A.at<float>(rowX, 2) = 1;

        A.at<float>(rowY, 0) = x0[i].y;
        A.at<float>(rowY, 1) = x0[i].x;
        A.at<float>(rowY, 3) = 1;

        B.at<float>(rowX) = dst[i].x;
        B.at<float>(rowY) = dst[i].y;
    }

    const cv::Mat AT = A.t();
    const cv::Mat ATA = AT * A;
    const cv::Mat R = ATA.inv() * AT * B; // R = [a b c d]^T

    // roi = [a -b c; b a d]
    roi.at<float>(0, 0) = R.at<float>(0, 0);
    roi.at<float>(0, 1) = -R.at<float>(1, 0);
    roi.at<float>(0, 2) = R.at<float>(2, 0);
    roi.at<float>(1, 0) = R.at<float>(1, 0);
    roi.at<float>(1, 1) = R.at<float>(0, 0);
    roi.at<float>(1, 2) = R.at<float>(3, 0);
    return roi;
}


/**
 * @brief Warps a face to a 112x112 crop using its five landmarks.
 *
 * The landmarks in @p finalBboxAlign.ppoint are read as five x values followed
 * by five y values (index j and j + 5) and matched against a fixed reference
 * template laid out the same way. The transform comes from
 * cv::estimateAffinePartial2D; when that returns an empty matrix the
 * closed-form least-squares fit getsrc_roi() is used instead.
 *
 * @param imageAlign     Image that contains the face.
 * @param finalBboxAlign Detection whose ppoint array holds the landmarks.
 * @return 112x112 aligned face with the same type as @p imageAlign.
 */
cv::Mat faceAlign(cv::Mat& imageAlign, Bbox& finalBboxAlign)
{
    const int kLandmarkCount = 5;
    const int kAlignedSize = 112;

    // Reference landmark positions: x0..x4 followed by y0..y4.
    const double dst_landmark[10] = {
        38.2946, 73.5318, 55.0252, 41.5493, 70.7299,
        51.6963, 51.5014, 71.7366, 92.3655, 92.2041 };

    std::vector<cv::Point2f> coord5points;
    std::vector<cv::Point2f> facePointsByMtcnn;
    coord5points.reserve(kLandmarkCount);
    facePointsByMtcnn.reserve(kLandmarkCount);

    for (int i = 0; i < kLandmarkCount; i++) {
        coord5points.push_back(cv::Point2f(static_cast<float>(dst_landmark[i]),
                                           static_cast<float>(dst_landmark[i + kLandmarkCount])));
        // Keep landmarks as floats; going through the integer cv::Point would
        // drop any fractional part before the fit.
        facePointsByMtcnn.push_back(cv::Point2f(static_cast<float>(finalBboxAlign.ppoint[i]),
                                                static_cast<float>(finalBboxAlign.ppoint[i + kLandmarkCount])));
    }

    cv::Mat warp_mat = cv::estimateAffinePartial2D(facePointsByMtcnn, coord5points);
    if (warp_mat.empty()) {
        warp_mat = getsrc_roi(facePointsByMtcnn, coord5points);
    }
    warp_mat.convertTo(warp_mat, CV_32FC1);

    cv::Mat alignFace = cv::Mat::zeros(kAlignedSize, kAlignedSize, imageAlign.type());
    cv::warpAffine(imageAlign, alignFace, warp_mat, alignFace.size());
    return alignFace;
}

/**
 * @brief Runs the face-embedding network on @p img and returns the feature vector.
 *
 * The image is resized to 112x112 (BGR), fed to the input blob "data", and the
 * blob "fc1" is extracted.
 *
 * @param ex  ncnn extractor to run (taken by value, as in the original signature).
 * @param img BGR image of the face.
 * @return Heap array of 128 floats allocated with new[]; the caller owns it and
 *         must release it with delete[]. Entries beyond the network's output
 *         width are zero.
 */
float* getFeatByMobileFaceNetNCNN(ncnn::Extractor ex, cv::Mat img)
{
    const int kFeatureLength = 128;

    // Value-initialise so a short (or failed) extraction never exposes garbage.
    float* feat = new float[kFeatureLength]();

    ncnn::Mat in = ncnn::Mat::from_pixels_resize(img.data, ncnn::Mat::PIXEL_BGR, img.cols, img.rows, 112, 112);
    ex.input("data", in);
    ncnn::Mat out;
    ex.extract("fc1", out);

    // Never write past the fixed-size buffer if the output is wider than expected.
    const int count = (std::min)(out.w, kFeatureLength);
    for (int j = 0; j < count; j++)
    {
        feat[j] = out[j];
    }
    return feat;
}

/**
 * @brief Turns @p facebox into a square centred on the same point and clipped
 *        to the image.
 *
 * The half side is half of the longer edge when @p max_b is true and half of
 * the shorter edge otherwise. After clipping to [0, margin_cols - 1] x
 * [0, margin_rows - 1] the box may no longer be square; in that case it is
 * squared again using the shorter edge.
 *
 * @param facebox     Input rectangle.
 * @param margin_rows Image height in pixels.
 * @param margin_cols Image width in pixels.
 * @param max_b       true: grow to the longer edge; false: shrink to the shorter edge.
 * @return The squared, clipped rectangle. If clipping leaves no area at all
 *         (the box lies outside the image) an empty rectangle is returned.
 */
static cv::Rect SquarePadding(cv::Rect facebox, int margin_rows, int margin_cols, bool max_b)
{
    const int c_x = facebox.x + facebox.width / 2;
    const int c_y = facebox.y + facebox.height / 2;

    int large = 0;
    if (max_b)
        large = (std::max)(facebox.height, facebox.width) / 2;
    else
        large = (std::min)(facebox.height, facebox.width) / 2;

    // Work with explicit edges instead of overloading cv::Rect's width/height
    // fields as right/bottom coordinates.
    const int left = (std::max)(0, c_x - large);
    const int top = (std::max)(0, c_y - large);
    const int right = (std::min)(c_x + large, margin_cols - 1);
    const int bottom = (std::min)(c_y + large, margin_rows - 1);

    // A box completely outside the image has a negative extent after clipping;
    // recursing on it would never converge, so stop here.
    if (right < left || bottom < top)
        return cv::Rect(left, top, 0, 0);

    const cv::Rect clipped(left, top, right - left, bottom - top);
    if (clipped.height != clipped.width)
        return SquarePadding(clipped, margin_rows, margin_cols, false);

    return clipped;
}

void processImage(const std::string& imagePath, cv::VideoWriter& videoWriter, const cv::Size& targetSize) {
    cv::Mat image = cv::imread(imagePath);

    // 调整图片大小
    cv::resize(image, image, targetSize);


void processFolder(const std::string& folderPath, cv::VideoWriter& videoWriter, const cv::Size& targetSize) {
    cv::String pattern = folderPath + "/*.jpg"; // 匹配 JPG 格式的图片
    std::vector imagePaths;
    cv::glob(pattern, imagePaths);

    for (const auto& imagePath : imagePaths) {
        processImage(imagePath, videoWriter, targetSize);

void getDirectoryNames(const std::string& folderPath, std::vector& directoryNames)
    std::string searchPattern = folderPath + "\\*";

    WIN32_FIND_DATAA findData;
    HANDLE hFind = FindFirstFileA(searchPattern.c_str(), &findData);

    if (hFind != INVALID_HANDLE_VALUE)
            std::string entryName = findData.cFileName;

            if ((findData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) && entryName != "." && entryName != "..")
                directoryNames.push_back(folderPath + "/" + entryName);
        } while (FindNextFileA(hFind, &findData));


/// One box described by its centre, its size and an integer id.
/// All members default to zero so a partially filled instance never carries
/// indeterminate values.
struct Object_Label {
    double o_x = 0.0;      // centre x coordinate
    double o_y = 0.0;      // centre y coordinate
    double o_width = 0.0;  // box width
    double o_height = 0.0; // box height
    int name = 0;          // integer identifier of the box
};

std::vector ReadObjectsFromTxt(const std::wstring& filePath)
    std::wifstream inputFile(filePath);

    if (!inputFile)
        std::wcerr << L"Failed to open input file: " << filePath << std::endl;
        return {};

    std::vector objects;
    std::wstring line;
    while (std::getline(inputFile, line))
        Object_Label obj;
        std::wstringstream ss(line);
        std::wstring token;

        // Parse the line using comma as the delimiter
        std::getline(ss, token, L',');

        double picname = std::stod(token.substr(token.find(L',') + 1));

        std::getline(ss, token, L',');

        obj.o_x = std::stod(token.substr(token.find(L',') + 1));

        std::getline(ss, token, L',');
        obj.o_y = std::stod(token.substr(token.find(L',') + 1));

        std::getline(ss, token, L',');
        obj.o_width = std::stod(token.substr(token.find(L',') + 1));

        std::getline(ss, token, L',');
        obj.o_height = std::stod(token.substr(token.find(L',') + 1));



    return objects;

/**
 * @brief Intersection-over-union of two axis-aligned boxes.
 *
 * Each box is given by one corner (x, y) and its extent (w, h), so it covers
 * [x, x + w] x [y, y + h].
 *
 * @param x1,y1,w1,h1 First box.
 * @param x2,y2,w2,h2 Second box.
 * @return intersection area / union area; 0 when the union has no area
 *         (for example two zero-sized boxes), instead of dividing by zero.
 */
double calculateIOU(double x1, double y1, double w1, double h1, double x2, double y2, double w2, double h2) {
    const double left = std::max(x1, x2);
    const double top = std::max(y1, y2);
    const double right = std::min(x1 + w1, x2 + w2);
    const double bottom = std::min(y1 + h1, y2 + h2);

    // Negative extents mean the boxes do not overlap on that axis.
    const double intersection = std::max(0.0, right - left) * std::max(0.0, bottom - top);
    const double union_area = w1 * h1 + w2 * h2 - intersection;

    if (union_area <= 0.0) {
        return 0.0;
    }
    return intersection / union_area;
}

// Entry point of the tracking evaluation driver.
//
// What the visible code does:
//   * builds an SCRFD detector and an ONet landmark detector from "./models",
//     and creates an ncnn extractor from the local net `squeezenet`;
//   * enumerates the sub-directories of "./img" with getDirectoryNames();
//   * for each of them reads labels from its "*.txt" files, turns the images of
//     its sub-directories into "output_video.avi", reads that video back frame
//     by frame, runs detection and BYTETracker on each frame and compares the
//     tracks with objects_label[0] using calculateIOU();
//   * prints mota / motp / idf1 / HOTAa per directory and finally streams the
//     *xreal vectors to "xreal.txt".
//
// NOTE(review): this block is damaged as seen here. Scope braces and template
// argument lists are missing, several statements are truncated (the bare
// "Process the objects" text, the block comment opened before
// `for (int j = 0; j < 10; j++)` that is never closed, the empty if/else
// bodies near the end, the unterminated `out <<` chain). The *xreal vectors
// are never filled in the visible code and no model-loading call on
// `squeezenet` is visible. Confirm against the original file before relying
// on any of it.
int main(int argc, char *argv[])

    std::string modelPath = "./models";
    ncnn::Net squeezenet;
    SCRFD* detector = new SCRFD(modelPath);
    ONet* detector_mtcnn = new ONet(modelPath);
    ncnn::Extractor ex = squeezenet.create_extractor();

    ofstream out("xreal.txt");

    vector motaxreal;// holds mota values that are written to the txt file
    vector motpxreal;// holds motp values that are written to the txt file
    vector misxreal;// holds misDetection values that are written to the txt file
    vector falsexreal;// holds falseDetection values that are written to the txt file
    vector N_gtxreal;// holds N_gt values that are written to the txt file
    vector IDSWxreal;// holds IDSW values that are written to the txt file
    vector N_trxreal;// holds N_tr values that are written to the txt file
    vector IDF1xreal;// holds IDF1 values that are written to the txt file
    vector HOTAaxreal;// holds HOTAa values that are written to the txt file

    // Set the input folder and the output video file name
    std::string inputFolder = "./img";
    std::string outputVideo = "output_video.avi";
    std::wstring_convert> converter;

    std::vector onedirectoryNames;
    std::vector finaldirectoryNames;// stores the sub-directory paths found under a given folder
    // Set the target image size
    //cv::Size targetSize(640, 480);
    getDirectoryNames(inputFolder, onedirectoryNames);
    for (const auto& path : onedirectoryNames) {
        std::wstring directory = converter.from_bytes(path);
        std::wstring searchPath = directory + L"\\*.txt";
        std::vector objects_label;
        int misDetection = 0;// number of missed targets, FP (label as written in the original comment)
        int falseDetection = 0;// number of false detections, FN (label as written in the original comment)
        int N_gt = 0;// total number of ground-truth targets
        int N_tr = 0;// total number of detected targets, TP
        double IDFP = 0.0;// number of targets whose id changed over the whole run
        double IDFN = 0.0;
        double IDSW = 0.0;// number of id switches compared with the previous frame
        double IDTP = 1.0;// the initial id is firstId; IDTP counts how often the id equals firstId over the whole run (TPA)
        double mota = 0.0;
        double sum_distance = 0.0;
        double motp = 0.0;
        double idf1 = 0.0;
        double Ac = 0.0;
        double HOTAa = 0.0;

        // Load the labels from the "*.txt" files of this directory; each file
        // overwrites objects_label, so only the last one read is kept.
        WIN32_FIND_DATAW fileData;
        HANDLE hFind = FindFirstFileW(searchPath.c_str(), &fileData);

        if (hFind != INVALID_HANDLE_VALUE)
                std::wstring filePath = directory + L"\\" + fileData.cFileName;
                objects_label = ReadObjectsFromTxt(filePath);

                // NOTE(review): the next line looks like a comment that lost its leading slashes.
                 Process the objects
                //for (const auto& obj : objects_label)
                //    // Do something with the object data
                //    std::cout << "Object: x=" << obj.o_x << ", y=" << obj.o_y
                //        << ", width=" << obj.o_width << ", height=" << obj.o_height << std::endl;
            } while (FindNextFileW(hFind, &fileData) != 0);


        // NOTE(review): finaldirectoryNames is appended to on every outer
        // iteration and is not cleared in the visible code.
        getDirectoryNames(path, finaldirectoryNames);
        for (int i = 0; i < finaldirectoryNames.size(); i++) {
            // Read the first image to get the video's width and height
            std::vector imagePaths;

            cv::glob(finaldirectoryNames[i], imagePaths);
            // NOTE(review): imagePaths[0] is read without checking that the glob found anything.
            cv::Mat firstImage = cv::imread(imagePaths[0]);
            int width = firstImage.cols;
            int height = firstImage.rows;
            cv::Size targetSize(width, height);

            // Create the output video writer
            cv::VideoWriter videoWriter(outputVideo, cv::VideoWriter::fourcc('M', 'J', 'P', 'G'), 25, cv::Size(width, height));

            if (!videoWriter.isOpened()) {
                std::cout << "无法创建输出视频编写器!" << std::endl;
                return 1;

            // Process the images in the folder
            processFolder(finaldirectoryNames[i], videoWriter, targetSize);

            // Release resources

            // Re-open the video that was just written and use it as the frame source.
            cv::VideoCapture mVideoCapture(outputVideo);

            //cv::VideoCapture mVideoCapture(0);
            if (!mVideoCapture.isOpened()) {
                std::cout << "fail to openn!" << std::endl;
                return 1;
            cv::Mat frame;
            mVideoCapture >> frame;
            int num_frames = 0;
            int fps = 30;
            BYTETracker tracker(fps, 3000);
            bool firstFrame = true;// whether this is the first frame of the current video in which the labelled target is detected
            int firstId = -1; // on the first frame where the labelled target is detected, record its id so later id changes can be checked

            while (true)
                mVideoCapture >> frame;

                if (frame.empty()) {
                    // Handle the case where all video frames have been read

                std::vector faceobjects;
                std::vector finalBbox;
                std::vector bbox;

                // Time the face detector on this frame.
                auto start = std::chrono::system_clock::now();
                detector->detect_scrfd(frame, faceobjects);
                auto end = std::chrono::system_clock::now();
                auto detect_time = std::chrono::duration_cast(end - start).count();//ms
                //detector->draw_faceobjects(frame, faceobjects);

                for (int i = 0; i < faceobjects.size(); i++) {
                    // NOTE(review): faceROI_Image is never assigned in the visible code before it is used.
                    cv::Mat faceROI_Image;
                    ncnn::Mat in = ncnn::Mat::from_pixels_resize(faceROI_Image.data,
                        ncnn::Mat::PIXEL_BGR, faceROI_Image.cols, faceROI_Image.rows, 48, 48);
                    // Pass to ONet
                    Bbox faceBbox = detector_mtcnn->onetDetect(in, faceobjects[i].rect.x,
                        faceobjects[i].rect.y, faceROI_Image.cols, faceROI_Image.rows);
                    //faceBbox.score = faceobjects[i].prob;
                    /*for (int j = 0; j < 10; j++) {
                        std::cout << "faceBbox[0].ppoint[" < bbox;
                bool matchLibrary = FALSE;

                for (int i = 0; i < num_box; i++) {
                    bbox[i] = cv::Rect(finalBbox[i].x1, finalBbox[i].y1,
                        finalBbox[i].x2 - finalBbox[i].x1 + 1, finalBbox[i].y2 - finalBbox[i].y1 + 1);
                    bbox[i] = SquarePadding(bbox[i], frame.rows, frame.cols, true);

                    cv::Mat alignedFace = faceAlign(frame, finalBbox[i]);
                    //cv::imshow("alignedFace", alignedFace);
                    float* featDetect = getFeatByMobileFaceNetNCNN(ex, alignedFace);
                    for (int j = 0; j < 128; j++)
                        finalBbox[i].fects[j] = featDetect[j];
                        //cout << i << " " << featDetect[i] << "\n";

                //detector->draw_faceobjects(frame, faceobjects);
                //std::cout << "--------------detecting---------------" << std::endl;

                start = std::chrono::system_clock::now();
                std::vector output_stracks = tracker.update(finalBbox);
                end = std::chrono::system_clock::now();
                auto track_time = std::chrono::duration_cast(end - start).count();//us

                //std::cout << "output_stracks.size()" << output_stracks.size() << std::endl;

                std::vector detections;
                Object_Label object_label;
                bool correct_detection = false;
                int lastId;// id from the previous frame, used to check whether the id changed since then
                double TPA = 1;
                double FNA = 0;
                double FPA = 0;

                for (unsigned long i = 0; i < output_stracks.size(); i++)
                    std::vector tlwh = output_stracks[i].tlwh;
                    bool vertical = tlwh[2] / tlwh[3] > 1.6;
                    if (tlwh[2] * tlwh[3] > 30 && !vertical)
                        bbox[i] = cv::Rect(tlwh[0], tlwh[1], tlwh[2] + 1, tlwh[3] + 1);
                        bbox[i] = SquarePadding(bbox[i], frame.rows, frame.cols, true);
                        cv::Scalar s = tracker.get_color(output_stracks[i].track_id);
                        cv::putText(frame, cv::format("%d %.1f%%", output_stracks[i].track_id, 100 * output_stracks[i].score),
                            cv::Point(bbox[i].x, bbox[i].y - 5), 0, 0.6, cv::Scalar(0, 0, 255), 1, cv::LINE_AA);

                        cv::rectangle(frame, cv::Rect(bbox[i].x, bbox[i].y, bbox[i].width, bbox[i].height), s, 2);

                        object_label.name = output_stracks[i].track_id;
                        object_label.o_x = static_cast(tlwh[0] + tlwh[2] / 2.0);
                        object_label.o_y = static_cast(tlwh[1] + tlwh[3] / 2.0);
                        object_label.o_width = static_cast(tlwh[2]);
                        object_label.o_height = static_cast(tlwh[3]);


                cv::putText(frame, cv::format("detect ms:%ld  track us:%ld  current frame:%d", detect_time, track_time, num_frames),
                    cv::Point(1, 40), cv::FONT_HERSHEY_PLAIN, 1, cv::Scalar(255, 255, 255), 1, 8);
                cv::imshow("bytetracker", frame);

                for (const Object_Label& det : detections) {
                    double iou = calculateIOU(
                        det.o_x - det.o_width / 2.0, 
                        det.o_y - det.o_height / 2.0, 
                        det.o_width, det.o_height,
                        objects_label[0].o_x - objects_label[0].o_width / 2.0,
                        objects_label[0].o_y - objects_label[0].o_height / 2.0,
                        objects_label[0].o_width, objects_label[0].o_height);
                    double distance = 
                        std::sqrt(std::pow(det.o_x - objects_label[0].o_x, 2) + std::pow(det.o_y - objects_label[0].o_y, 2));
                    if (iou >= 0.5 && distance <= 20.0) {
                        correct_detection = true;
                        N_tr++; // number of successfully matched detections, TP
                        sum_distance += distance;
                        if (firstFrame == true) {
                            firstId = det.name;
                            lastId = det.name;
                            firstFrame = false;
                        }else {
                            if (firstId != det.name) {

                                FPA ++;
                                FNA ++;

                            else {

                                TPA ++;
                        if (firstFrame == false) {
                            if (det.name != lastId) {
                        Ac += (abs(TPA) / (abs(TPA) + abs(FNA) + abs(FPA)));
                if (correct_detection == false) {
                if (cv::waitKey(30) == 27) // Wait for 'esc' key press to exit

            std::cout << "视频播放完成!" << std::endl;

        mota = 1 - (double)(misDetection + falseDetection + 2 * IDSW) / (double)N_gt;
        motp = sum_distance / (double)N_tr;

        idf1 = (2 * IDTP ) / (2 * IDTP + IDFP + IDFN);

        HOTAa = std::sqrt(abs(Ac / (double)(abs(N_tr) + abs(misDetection) + abs(falseDetection))));

        std::cout << "mota = " << mota << std::endl;
        std::cout << "motp = " << motp << std::endl;
        std::cout << "idf1 = " << idf1 << std::endl;
        std::cout << "HOTAa = " << HOTAa << std::endl;




    for (int i = 0; i < onedirectoryNames.size(); i++)
        //MOTA  MOTP  IDF1  HOTA  FP  FN  N_gt  IDs  N_tr
        out << i << "  " << motaxreal[i] 
                 << "  " << motpxreal[i]
                 << "  " << IDF1xreal[i]
                 << "  " << HOTAaxreal[i]
                 << "  " << misxreal[i]
                 << "  " << falsexreal[i]
                 << "  " << N_gtxreal[i]
                 << "  " << IDSWxreal[i]
                 << "  " << N_trxreal[i]               


