该功能基于 darknet 框架,使用 YOLOv3 进行训练和检测;跟踪部分采用简单的基于距离的跟踪逻辑。
一、网址:https://github.com/AlexeyAB/darknet
二、训练方法请参考我的另一篇博客:https://blog.csdn.net/zhulong1984/article/details/82344685
三、跟踪代码:
#pragma once
#include "HeadDetect.h"
#include "opencv2/opencv.hpp"
using namespace cv;
using namespace std;
#define MAX_TRACK_COUNT 30// maximum number of simultaneously tracked targets (size of the tracker pool)
#define MAX_TRAJECTORY_COUNT 200// maximum number of trajectory points stored per target
#define MAX_MISS_FRAME 30// a target is dropped after MORE than this many consecutive unmatched frames
#define MAX_TRACK_DIST 20000// association gate on the SQUARED centre distance (20000 ~= 141^2), in HDRect coordinate units
// State of one tracker slot in HeadTrack's fixed-size pool.
// NOTE(review): `kf` is a raw owning pointer — it is allocated with `new` when a
// slot becomes active and deleted when the target is dropped; the struct itself
// does not free it.
struct Tracker
{
int nID;// ID assigned from HeadTrack's running counter when the slot is (re)used
int nMissTimes;// consecutive frames without a matching detection
bool bMatchID;// set when the tracker is created or matched to a detection
bool bTrack;// true while the slot holds an active target
bool bStatOK;// true once this target has been counted as in/out
bool bInitDirection;// true once nInitDirection has been computed
int nInitDirection;// side of the counting line (+1 / -1) the target started on
cv::KalmanFilter *kf;// per-target Kalman filter (owned, see note above)
cv::Point predPoint;// predicted centre used for association in the next frame
cv::Point curCentre;// centre of the last matched detection
cv::Rect curRect;// box of the last matched detection
int nTrajCount;// number of valid entries in centres[] / blobs[]
cv::Point centres[MAX_TRAJECTORY_COUNT];// trajectory of box centres, oldest first
cv::Rect blobs[MAX_TRAJECTORY_COUNT];// boxes matching centres[]
};
// Head tracker and line-crossing counter.
// trackProcess() associates each frame's head detections with a fixed pool of
// Kalman-filtered trackers; trackInOutStatistics() counts trackers whose
// trajectory crossed the line (m_lineStart, m_lineEnd).
class HeadTrack
{
public:
    HeadTrack();
    ~HeadTrack();

    // Not copyable: the tracker pool (and each active tracker's
    // cv::KalmanFilter) is held through raw pointers, so a member-wise copy
    // would alias them between two instances.
    HeadTrack(const HeadTrack &) = delete;
    HeadTrack &operator=(const HeadTrack &) = delete;

    int m_InCount;        // crossings counted as "in"
    int m_OutCount;       // crossings counted as "out"
    Tracker *m_pTrackers; // array of MAX_TRACK_COUNT tracker slots, allocated in the constructor

    // Configures the counting line and the detection-region mask, and resets
    // the tracker-ID counter and the in/out counts.
    void trackInitial(int arrowStart, int arrowEnd, cv::Point lineStart, cv::Point lineEnd, cv::Mat detectRegion);
    // Updates all trackers with one frame of head detections.
    void trackProcess(HDBox_Container *pHDBox_Container);
    // Counts trackers whose trajectory has crossed the counting line.
    void trackInOutStatistics();
    // Returns -1 or 1 for the side of the line a point lies on, or 0 when the
    // point is outside the segment's x and y span.
    int pointInLineSide(cv::Point point, cv::Point lineStart, cv::Point lineEnd);

private:
    int m_trackID;          // last ID handed out to a new tracker
    int m_nArrowStart;      // compared with a tracker's initial side (+1/-1) to decide "in" vs "out"
    int m_nArrowEnd;        // stored by trackInitial(); NOTE(review): not read in this file
    cv::Point m_lineStart;  // counting line start point
    cv::Point m_lineEnd;    // counting line end point
    cv::Mat m_detectRegion; // mask; only detections whose centre pixel is 255 are tracked
};
#include "headTrack.h"
// Zeroes the counters and allocates the pool of MAX_TRACK_COUNT tracker slots,
// every one of them inactive (bTrack == false, kf == nullptr, nTrajCount == 0).
HeadTrack::HeadTrack()
    : m_InCount(0)
    , m_OutCount(0)
    , m_pTrackers(nullptr)
    , m_trackID(0)
    , m_nArrowStart(0)
    , m_nArrowEnd(0)
{
    // The trailing "()" value-initialises EVERY element. The former
    // memset(m_pTrackers, 0, sizeof(Tracker)) only cleared the first slot and
    // left bTrack / kf indeterminate in the remaining MAX_TRACK_COUNT - 1 slots.
    m_pTrackers = new Tracker[MAX_TRACK_COUNT]();
}
// Releases the Kalman filter of every still-active tracker, then the tracker
// pool itself (both were previously leaked).
HeadTrack::~HeadTrack()
{
    if (m_pTrackers == nullptr)
    {
        return;
    }
    for (int i = 0; i < MAX_TRACK_COUNT; i++)
    {
        Tracker &tracker = m_pTrackers[i];
        // kf is allocated when a slot becomes active and deleted when the slot
        // is dropped, so only active slots still own a filter.
        if (tracker.bTrack)
        {
            delete tracker.kf;
            tracker.kf = nullptr;
        }
    }
    delete[] m_pTrackers;
    m_pTrackers = nullptr;
}
void HeadTrack::trackInitial(int arrowStart, int arrowEnd, cv::Point lineStart, cv::Point lineEnd, cv::Mat detectRegion)
{
m_trackID = 0;
m_InCount = 0;
m_OutCount = 0;
m_nArrowStart = arrowStart;
m_nArrowEnd = arrowEnd;
m_lineStart = lineStart;
m_lineEnd = lineEnd;
m_detectRegion = detectRegion.clone();
}
void HeadTrack::trackProcess(HDBox_Container *pHDBox_Container)
{
int i = 0, j = 0;
//把太小的rect删除
HDBox_Container boxContainer;
boxContainer.headCount = 0;
int width = m_detectRegion.cols;
uchar *pDetectData = m_detectRegion.data;
int step1 = m_detectRegion.step1();
//cv::imshow("m_detectRegion", m_detectRegion);
//cv::waitKey(10);
for (i = 0;i< pHDBox_Container->headCount;i++)
{
HDRect *pHDRect = &pHDBox_Container->candidates[i];
int nx = (pHDRect->right + pHDRect->left) / 2;
int ny = (pHDRect->top + pHDRect->bottom) / 2;
int rectW = pHDRect->right - pHDRect->left;
int rectH = pHDRect->bottom - pHDRect->top;
if (rectW > 60 && rectH > 60 && (*(pDetectData + ny*width + nx) == 255))
{
boxContainer.candidates[boxContainer.headCount++] = pHDBox_Container->candidates[i];
}
}
bool bMatch[HD_MAX_HEADS] = { false };
for (i = 0;i< boxContainer.headCount;i++)
{
bMatch[i] = false;
}
for (i = 0; i < MAX_TRACK_COUNT; i++)
{
Tracker *pTracker = &m_pTrackers[i];
if (pTracker->bTrack)
{
bool bMinst = false;
int nMatchID = -1;
int maxDist = MAX_TRACK_DIST;
for (j = 0; j < boxContainer.headCount; j++)
{
if (!bMatch[j])
{
HDRect *pHDRect = &boxContainer.candidates[j];
cv::Rect curRect;
curRect.x = pHDRect->left;
curRect.y = pHDRect->top;
curRect.width = pHDRect->right - pHDRect->left;
curRect.height = pHDRect->bottom - pHDRect->top;
int nx = (pHDRect->left + pHDRect->right) / 2;
int ny = (pHDRect->top + pHDRect->bottom) / 2;
int dist = (pTracker->predPoint.x - nx)*(pTracker->predPoint.x - nx) + (pTracker->predPoint.y - ny)*(pTracker->predPoint.y - ny);
if (dist < maxDist)
{
maxDist = dist;
pTracker->curRect = curRect;//后面更新用
pTracker->curCentre.x = nx;
pTracker->curCentre.y = ny;
nMatchID = j;
bMinst = true;
}
}
}
//找到了blob
if (bMinst)
{
bMatch[nMatchID] = true;
HDRect *pHDRect = &boxContainer.candidates[nMatchID];
cv::Rect curRect;
curRect.x = pHDRect->left;
curRect.y = pHDRect->top;
curRect.width = pHDRect->right - pHDRect->left;
curRect.height = pHDRect->bottom - pHDRect->top;
int nx = (pHDRect->left + pHDRect->right) / 2;
int ny = (pHDRect->top + pHDRect->bottom) / 2;
pTracker->bMatchID = true;
pTracker->nMissTimes = 0;
pTracker->curCentre.x = nx;
pTracker->curCentre.y = ny;
pTracker->curRect = curRect;
//更新预测值
Mat measurement = Mat::zeros(2, 1, CV_32F);
measurement.at(0) = (float)nx;
measurement.at(1) = (float)ny;
pTracker->kf->correct(measurement);
Mat prediction = pTracker->kf->predict();
pTracker->predPoint = Point(prediction.at(0), prediction.at(1)); //预测值(x',y')
cv::Point centre = pTracker->centres[pTracker->nTrajCount - 1];
if ((centre.x - nx)*(centre.x - nx) + (centre.y - ny)*(centre.y - ny) > 30)
{
pTracker->centres[pTracker->nTrajCount].x = nx;
pTracker->centres[pTracker->nTrajCount].y = ny;
pTracker->blobs[pTracker->nTrajCount] = curRect;
pTracker->nTrajCount++;
if (pTracker->nTrajCount >= MAX_TRAJECTORY_COUNT - 1)
{
pTracker->nTrajCount = MAX_TRAJECTORY_COUNT - 1;
for (int k = 1; k < pTracker->nTrajCount; k++)
{
pTracker->centres[k - 1] = pTracker->centres[k];
pTracker->blobs[k - 1] = pTracker->blobs[k];
}
}
}
}
else//没找到blob
{
pTracker->nMissTimes++;
//Mat prediction = pTracker->kf->predict();
//pTracker->predPoint = Point(prediction.at(0), prediction.at(1)); //预测值(x',y')
//更新预测值
Mat measurement = Mat::zeros(2, 1, CV_32F);
measurement.at(0) = (float)pTracker->curCentre.x;
measurement.at(1) = (float)pTracker->curCentre.y;
pTracker->kf->correct(measurement);
Mat prediction = pTracker->kf->predict();
pTracker->predPoint = Point(prediction.at(0), prediction.at(1)); //预测值(x',y')
if (pTracker->nMissTimes > MAX_MISS_FRAME)
{
pTracker->bTrack = false;
delete pTracker->kf;
}
}
}
}
//没有匹配上的,需要重新创建目标
for (i = 0; i < boxContainer.headCount; i++)
{
HDRect *pHDRect = &boxContainer.candidates[i];
cv::Rect curRect;
curRect.x = pHDRect->left;
curRect.y = pHDRect->top;
curRect.width = pHDRect->right - pHDRect->left;
curRect.height = pHDRect->bottom - pHDRect->top;
int nx = (pHDRect->left + pHDRect->right) / 2;
int ny = (pHDRect->top + pHDRect->bottom) / 2;
if (!bMatch[i])
{
for (j = 0; j < MAX_TRACK_COUNT; j++)
{
Tracker *pTracker = &m_pTrackers[j];
if (!pTracker->bTrack)
{
pTracker->bTrack = true;
pTracker->bMatchID = true;
pTracker->bStatOK = false;
pTracker->bInitDirection = false;
pTracker->nID = ++m_trackID;
pTracker->curCentre.x = nx;
pTracker->curCentre.y = ny;
pTracker->curRect = curRect;
pTracker->nMissTimes = 0;
pTracker->predPoint.x = nx;
pTracker->predPoint.y = ny;
pTracker->kf = new cv::KalmanFilter(4, 2, 0);
pTracker->kf->transitionMatrix = (Mat_(4, 4) << 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1);//转移矩阵A
cv::setIdentity(pTracker->kf->measurementMatrix); //测量矩阵H
cv::setIdentity(pTracker->kf->processNoiseCov, Scalar::all(1e-5)); //系统噪声方差矩阵Q
cv::setIdentity(pTracker->kf->measurementNoiseCov, Scalar::all(1e-1));//测量噪声方差矩阵R
cv::setIdentity(pTracker->kf->errorCovPost, Scalar::all(1)); //后验错误估计协方差矩阵P
pTracker->kf->statePost = (Mat_(4, 1) << nx, ny, 0, 0);
Mat prediction = pTracker->kf->predict();
pTracker->predPoint = Point(prediction.at(0), prediction.at(1)); //预测值(x',y')
pTracker->centres[0].x = nx;
pTracker->centres[0].y = ny;
pTracker->blobs[0] = curRect;
pTracker->nTrajCount = 1;
break;
}
}
}
}
}
// Classifies which side of the line (lineStart -> lineEnd) a point lies on.
// Returns 0 when the point is in neither of these ranges:
//   - strictly between the endpoints' x coordinates;
//   - strictly between the endpoints' y coordinates.
// Otherwise let v0 = (p - lineStart) x (lineEnd - lineStart) (2-D cross
// product). The function returns -1 when v0 is non-zero and its sign differs
// from the sign of (x1 - x0); a vertical line (x1 == x0) is treated as
// positive. In all other cases it returns 1, including points on the line.
int HeadTrack::pointInLineSide(cv::Point point, cv::Point lineStart, cv::Point lineEnd)
{
    const int x0 = lineStart.x;
    const int y0 = lineStart.y;
    const int x1 = lineEnd.x;
    const int y1 = lineEnd.y;

    // First make sure the point lies within the segment's x or y span.
    const bool bFlagX = (x0 > x1) ? (point.x > x1 && point.x < x0)
                                  : (point.x > x0 && point.x < x1);
    const bool bFlagY = (y0 > y1) ? (point.y > y1 && point.y < y0)
                                  : (point.y > y0 && point.y < y1);
    if (!bFlagX && !bFlagY)
    {
        return 0;
    }

    const int v0 = (point.x - x0) * (y1 - y0) - (point.y - y0) * (x1 - x0);
    if (v0 == 0)
    {
        return 1;
    }
    // Compare signs instead of testing v0 * (x1 - x0) < 0: for HD-sized
    // coordinates that product can exceed INT_MAX (signed overflow is UB).
    const bool bLineGoesLeft = (x1 - x0) < 0;
    return ((v0 < 0) != bLineGoesLeft) ? -1 : 1;
}
void HeadTrack::trackInOutStatistics()
{
int i = 0, j = 0;
for (i = 0; i < MAX_TRACK_COUNT; i++)
{
Tracker *pTracker = &m_pTrackers[i];
if (pTracker->bTrack && pTracker->nTrajCount > 20 && !pTracker->bStatOK)
{
if (!pTracker->bInitDirection)
{
int count0 = 0;
for (j = 0; j < 10; j++)
{
int flag = pointInLineSide(pTracker->centres[j], m_lineStart, m_lineEnd);
count0 += flag;
}
if (count0 > 0)
{
pTracker->nInitDirection = 1;
}
else
{
pTracker->nInitDirection = -1;
}
}
int count1 = 0;
for (j = pTracker->nTrajCount - 10; j < pTracker->nTrajCount - 1; j++)
{
int flag = pointInLineSide(pTracker->centres[j], m_lineStart, m_lineEnd);
if (flag != 0 && pTracker->nInitDirection != flag)
{
count1++;
}
}
if (count1 > 6)
{
if (pTracker->nInitDirection == m_nArrowStart)
{
m_InCount++;
}
else
{
m_OutCount++;
}
pTracker->bStatOK = true;
}
}
}
}
四、效果展示:
算法 demo:封装了一个早期版本的演示程序,用于验证算法可以在 Windows 上运行。算法使用 VS2015 + CUDA 环境编译,运行需要 NVIDIA GTX 系列显卡:
demo网盘地址:链接:https://pan.baidu.com/s/1IKaGZOqDEO-McBnkwhunzg 提取码:c9ga
部分原始数据(需要自行标注):链接:https://pan.baidu.com/s/1AOasuALCc7LX6cspyYUrcg
提取码:rn13