深度学习 客流统计 人流计数

该功能使用 darknet 框架,用 YOLOv3 进行训练和检测;跟踪部分采用简单的基于距离的匹配逻辑,并用卡尔曼滤波预测目标位置。

一、darknet(AlexeyAB 版本)项目地址:https://github.com/AlexeyAB/darknet

二、训练方法请参考我的另一篇博客:https://blog.csdn.net/zhulong1984/article/details/82344685

三、跟踪代码(先是头文件 headTrack.h,后面是对应的实现文件):

#pragma once

#include "HeadDetect.h"
#include "opencv2/opencv.hpp"

using namespace cv;
using namespace std;

// Capacity and tuning limits of the head tracker.
#define MAX_TRACK_COUNT		    30// maximum number of simultaneously tracked targets (size of the tracker pool)
#define MAX_TRAJECTORY_COUNT    200// maximum number of trajectory points stored per target
#define MAX_MISS_FRAME			30// a target is dropped after missing detections for more than this many consecutive frames
#define MAX_TRACK_DIST          20000// association gate; compared against the SQUARED pixel distance (about 141 px)

/// State of one tracked target; one slot of the HeadTrack tracker pool.
///
/// Every scalar/pointer member has a default initializer so that a freshly
/// allocated pool (`new Tracker[MAX_TRACK_COUNT]`) starts with all slots
/// inactive (bTrack == false, kf == nullptr) instead of indeterminate values.
struct Tracker 
{
	int nID = 0;                    // ID handed out when the slot was (re)activated
	int nMissTimes = 0;             // consecutive frames without a matching detection
	bool bMatchID = false;          // set when the slot is created or matched
	bool bTrack = false;            // slot currently holds an active target
	bool bStatOK = false;           // target has already been counted as in/out
	bool bInitDirection = false;    // nInitDirection has been determined
	int nInitDirection = 0;         // side of the counting line the target started on (+1 / -1)
	cv::KalmanFilter *kf = nullptr; // owned filter; allocated on activation, deleted when the target is dropped
	cv::Point predPoint;            // predicted centre for the next frame
	cv::Point curCentre;            // centre of the most recently matched box
	cv::Rect curRect;               // most recently matched box
	int nTrajCount = 0;             // number of valid entries in centres[] / blobs[]
	cv::Point centres[MAX_TRAJECTORY_COUNT]; // trajectory of box centres, oldest first
	cv::Rect blobs[MAX_TRAJECTORY_COUNT];    // boxes corresponding to centres[]
};

/// Tracks detected heads across frames and counts crossings of a counting line.
///
/// Intended call sequence (inferred from the interface; confirm against the
/// caller): trackInitial() once, then per frame trackProcess() followed by
/// trackInOutStatistics().
///
/// The tracker pool and the Kalman filters of active trackers are owned through
/// raw pointers, so the class is non-copyable.
class HeadTrack
{
public:
	HeadTrack();
	~HeadTrack();

	// A copy would alias m_pTrackers and the per-tracker Kalman filters and
	// lead to double deletion, so copying is disabled.
	HeadTrack(const HeadTrack&) = delete;
	HeadTrack& operator=(const HeadTrack&) = delete;

	int m_InCount;          // targets counted as entering (start side == arrowStart)
	int m_OutCount;         // targets counted as leaving
	Tracker *m_pTrackers;   // owned pool of MAX_TRACK_COUNT tracker slots

	/// Configure counting and reset the counters and ID numbering.
	/// @param arrowStart   compared with a target's start side (+1/-1); equal means "in"
	/// @param arrowEnd     stored only; not read by the counting code in this file
	/// @param lineStart    first end point of the counting line
	/// @param lineEnd      second end point of the counting line
	/// @param detectRegion mask; detections whose centre pixel is not 255 are ignored
	void trackInitial(int arrowStart, int arrowEnd, cv::Point lineStart, cv::Point lineEnd, cv::Mat detectRegion);

	/// Associate one frame of detector output with the trackers.
	void trackProcess(HDBox_Container *pHDBox_Container);

	/// Update m_InCount / m_OutCount from the current trajectories.
	void trackInOutStatistics();

	/// Side of the line the point lies on: 0 when the point is outside both the
	/// x-span and the y-span of the line, otherwise -1 or +1.
	int pointInLineSide(cv::Point point, cv::Point lineStart, cv::Point lineEnd);

private:

	int m_trackID;          // last ID handed out to a new tracker
	int m_nArrowStart;      // see trackInitial(arrowStart)
	int	m_nArrowEnd;        // see trackInitial(arrowEnd)
	cv::Point m_lineStart;  // counting line start point
	cv::Point m_lineEnd;    // counting line end point
	cv::Mat m_detectRegion; // private copy of the detection mask
};



#include "headTrack.h"

// Construct an idle tracker: zero counters and a pool of MAX_TRACK_COUNT
// inactive slots. The counting line and detection mask stay empty until
// trackInitial() is called.
HeadTrack::HeadTrack()
	: m_InCount(0)
	, m_OutCount(0)
	  // The trailing "()" value-initializes EVERY slot (bTrack = false, kf = nullptr, ...).
	  // A memset of sizeof(Tracker) bytes would only clear the first slot and
	  // leave the others with indeterminate values.
	, m_pTrackers(new Tracker[MAX_TRACK_COUNT]())
	, m_trackID(0)
	, m_nArrowStart(0)
	, m_nArrowEnd(0)
{
}

// Release the Kalman filters still owned by active trackers and the tracker pool.
// A dropped tracker (bTrack == false) has already deleted its filter, so only
// active slots are freed here.
HeadTrack::~HeadTrack()
{
	if (m_pTrackers == nullptr)
	{
		return;
	}

	for (int i = 0; i < MAX_TRACK_COUNT; i++)
	{
		Tracker *pTracker = &m_pTrackers[i];
		if (pTracker->bTrack)
		{
			delete pTracker->kf;
			pTracker->kf = nullptr;
		}
	}

	delete[] m_pTrackers;
	m_pTrackers = nullptr;
}

// (Re)configure line-crossing counting.
// Stores the counting line, the arrow settings and a deep copy of the detection
// mask, then restarts the in/out counters and the ID generator. Tracker slots
// that are already active are left as they are.
void HeadTrack::trackInitial(int arrowStart, int arrowEnd, cv::Point lineStart, cv::Point lineEnd, cv::Mat detectRegion)
{
	// Counting geometry and direction configuration.
	m_lineStart = lineStart;
	m_lineEnd = lineEnd;
	m_nArrowStart = arrowStart;
	m_nArrowEnd = arrowEnd;

	// Deep copy so later edits to the caller's Mat cannot change the mask.
	m_detectRegion = detectRegion.clone();

	// Restart statistics and ID numbering.
	m_OutCount = 0;
	m_InCount = 0;
	m_trackID = 0;
}

void HeadTrack::trackProcess(HDBox_Container *pHDBox_Container)
{
	int i = 0, j = 0;

	//把太小的rect删除
	HDBox_Container boxContainer;
	boxContainer.headCount = 0;
	int width = m_detectRegion.cols;
	uchar *pDetectData = m_detectRegion.data;

	int step1 = m_detectRegion.step1();
	//cv::imshow("m_detectRegion", m_detectRegion);
	//cv::waitKey(10);

	for (i = 0;i< pHDBox_Container->headCount;i++)
	{
		HDRect *pHDRect = &pHDBox_Container->candidates[i];
		int nx = (pHDRect->right + pHDRect->left) / 2;
		int ny = (pHDRect->top + pHDRect->bottom) / 2;
		int rectW = pHDRect->right - pHDRect->left;
		int rectH = pHDRect->bottom - pHDRect->top;
		if (rectW > 60 && rectH > 60 && (*(pDetectData + ny*width + nx) == 255))
		{
			boxContainer.candidates[boxContainer.headCount++] = pHDBox_Container->candidates[i];
		}
	}

	bool bMatch[HD_MAX_HEADS] = { false };
	for (i = 0;i< boxContainer.headCount;i++)
	{
		bMatch[i] = false;
	}

	for (i = 0; i < MAX_TRACK_COUNT; i++)
	{
		Tracker *pTracker = &m_pTrackers[i];
		if (pTracker->bTrack)
		{
			bool bMinst = false;
			int nMatchID = -1;
			int maxDist = MAX_TRACK_DIST;
			for (j = 0; j < boxContainer.headCount; j++)
			{
				if (!bMatch[j])
				{
					HDRect *pHDRect = &boxContainer.candidates[j];
					cv::Rect curRect;
					curRect.x = pHDRect->left;
					curRect.y = pHDRect->top;
					curRect.width = pHDRect->right - pHDRect->left;
					curRect.height = pHDRect->bottom - pHDRect->top;
					int nx = (pHDRect->left + pHDRect->right) / 2;
					int ny = (pHDRect->top + pHDRect->bottom) / 2;

					int dist = (pTracker->predPoint.x - nx)*(pTracker->predPoint.x - nx) + (pTracker->predPoint.y - ny)*(pTracker->predPoint.y - ny);
					if (dist < maxDist)
					{
						maxDist = dist;
						pTracker->curRect = curRect;//后面更新用
						pTracker->curCentre.x = nx;
						pTracker->curCentre.y = ny;
						nMatchID = j;
						bMinst = true;
					}
				}
			}

			//找到了blob
			if (bMinst)
			{
				bMatch[nMatchID] = true;

				HDRect *pHDRect = &boxContainer.candidates[nMatchID];
				cv::Rect curRect;
				curRect.x = pHDRect->left;
				curRect.y = pHDRect->top;
				curRect.width = pHDRect->right - pHDRect->left;
				curRect.height = pHDRect->bottom - pHDRect->top;
				int nx = (pHDRect->left + pHDRect->right) / 2;
				int ny = (pHDRect->top + pHDRect->bottom) / 2;

				pTracker->bMatchID = true;
				pTracker->nMissTimes = 0;
				pTracker->curCentre.x = nx;
				pTracker->curCentre.y = ny;
				pTracker->curRect = curRect;

				//更新预测值
				Mat measurement = Mat::zeros(2, 1, CV_32F);
				measurement.at(0) = (float)nx;
				measurement.at(1) = (float)ny;
				pTracker->kf->correct(measurement);

				Mat prediction = pTracker->kf->predict();
				pTracker->predPoint = Point(prediction.at(0), prediction.at(1)); //预测值(x',y')

				cv::Point centre = pTracker->centres[pTracker->nTrajCount - 1];
				if ((centre.x - nx)*(centre.x - nx) + (centre.y - ny)*(centre.y - ny) > 30)
				{
					pTracker->centres[pTracker->nTrajCount].x = nx;
					pTracker->centres[pTracker->nTrajCount].y = ny;
					pTracker->blobs[pTracker->nTrajCount] = curRect;
					pTracker->nTrajCount++;
					if (pTracker->nTrajCount >= MAX_TRAJECTORY_COUNT - 1)
					{
						pTracker->nTrajCount = MAX_TRAJECTORY_COUNT - 1;
						for (int k = 1; k < pTracker->nTrajCount; k++)
						{
							pTracker->centres[k - 1] = pTracker->centres[k];
							pTracker->blobs[k - 1] = pTracker->blobs[k];
						}
					}
				}
			}	
			else//没找到blob
			{
				pTracker->nMissTimes++;
				//Mat prediction = pTracker->kf->predict();
				//pTracker->predPoint = Point(prediction.at(0), prediction.at(1)); //预测值(x',y')
																							   //更新预测值
				Mat measurement = Mat::zeros(2, 1, CV_32F);
				measurement.at(0) = (float)pTracker->curCentre.x;
				measurement.at(1) = (float)pTracker->curCentre.y;
				pTracker->kf->correct(measurement);

				Mat prediction = pTracker->kf->predict();
				pTracker->predPoint = Point(prediction.at(0), prediction.at(1)); //预测值(x',y')

				if (pTracker->nMissTimes > MAX_MISS_FRAME)
				{
					pTracker->bTrack = false;
					delete pTracker->kf;
				}
			}
		}
	}

	//没有匹配上的,需要重新创建目标
	for (i = 0; i < boxContainer.headCount; i++)
	{
		HDRect *pHDRect = &boxContainer.candidates[i];
		cv::Rect curRect;
		curRect.x = pHDRect->left;
		curRect.y = pHDRect->top;
		curRect.width = pHDRect->right - pHDRect->left;
		curRect.height = pHDRect->bottom - pHDRect->top;
		int nx = (pHDRect->left + pHDRect->right) / 2;
		int ny = (pHDRect->top + pHDRect->bottom) / 2;
		if (!bMatch[i])
		{
			for (j = 0; j < MAX_TRACK_COUNT; j++)
			{
				Tracker *pTracker = &m_pTrackers[j];
				if (!pTracker->bTrack)
				{
					pTracker->bTrack = true;
					pTracker->bMatchID = true;
					pTracker->bStatOK = false;
					pTracker->bInitDirection = false;
					pTracker->nID = ++m_trackID;

					pTracker->curCentre.x = nx;
					pTracker->curCentre.y = ny;
					pTracker->curRect = curRect;
					pTracker->nMissTimes = 0;
					pTracker->predPoint.x = nx;
					pTracker->predPoint.y = ny;

					pTracker->kf = new cv::KalmanFilter(4, 2, 0);
					pTracker->kf->transitionMatrix = (Mat_(4, 4) << 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1);//转移矩阵A
					cv::setIdentity(pTracker->kf->measurementMatrix);                     //测量矩阵H
					cv::setIdentity(pTracker->kf->processNoiseCov, Scalar::all(1e-5));    //系统噪声方差矩阵Q
					cv::setIdentity(pTracker->kf->measurementNoiseCov, Scalar::all(1e-1));//测量噪声方差矩阵R
					cv::setIdentity(pTracker->kf->errorCovPost, Scalar::all(1));          //后验错误估计协方差矩阵P
					pTracker->kf->statePost = (Mat_(4, 1) << nx, ny, 0, 0);

					Mat prediction = pTracker->kf->predict();
					pTracker->predPoint = Point(prediction.at(0), prediction.at(1)); //预测值(x',y')

					pTracker->centres[0].x = nx;
					pTracker->centres[0].y = ny;
					pTracker->blobs[0] = curRect;
					pTracker->nTrajCount = 1;
					break;
				}
			}
		}
	}
}

// Classify on which side of the counting line a point lies.
//
// The point must lie strictly inside the line's x-span OR strictly inside its
// y-span; otherwise 0 is returned (the point is treated as not near the segment).
//
// Otherwise the result comes from the sign of the cross product
// v0 = (P - start) x (end - start):
//  - For a non-vertical line, the sign is normalised by the sign of (x1 - x0).
//    The result is -1 when the point is below the line in image coordinates
//    (larger y than the line at that x) and +1 otherwise, independent of the
//    order of the end points.
//  - For a vertical line, the sign of v0 alone is used.
// A point exactly on the line yields +1.
int HeadTrack::pointInLineSide(cv::Point point, cv::Point lineStart, cv::Point lineEnd)
{
	const int x0 = lineStart.x;
	const int y0 = lineStart.y;
	const int x1 = lineEnd.x;
	const int y1 = lineEnd.y;

	// Open-interval span tests, independent of the order of the end points.
	const bool bFlagX = (x0 > x1) ? (point.x > x1 && point.x < x0)
	                              : (point.x < x1 && point.x > x0);
	const bool bFlagY = (y0 > y1) ? (point.y > y1 && point.y < y0)
	                              : (point.y < y1 && point.y > y0);
	if (!bFlagX && !bFlagY)
	{
		return 0;
	}

	// 64-bit arithmetic: v0 * v1 can exceed the int range for full-HD and larger
	// frames, which would be undefined behaviour with plain int.
	const long long v0 = static_cast<long long>(point.x - x0) * (y1 - y0)
	                   - static_cast<long long>(point.y - y0) * (x1 - x0);
	const long long v1 = static_cast<long long>(x1) - x0;

	if (v1 == 0)
	{
		// Vertical line: the sign of the cross product alone decides.
		return (v0 < 0) ? -1 : 1;
	}
	return (v0 * v1 < 0) ? -1 : 1;
}

// Count targets that have crossed the counting line.
//
// Only active, not-yet-counted trackers with more than 20 trajectory points
// are considered. For each of them:
//  - The start side (nInitDirection) is determined once, from the sum of
//    pointInLineSide() over the first 10 trajectory points: +1 if the sum is
//    positive, otherwise -1.
//  - The target counts as crossed when more than 6 of the 9 points
//    centres[nTrajCount-10 .. nTrajCount-2] lie on the opposite side.
//  - A crossed target increments m_InCount when its start side equals
//    m_nArrowStart, otherwise m_OutCount. It is then marked (bStatOK) and never
//    counted again.
void HeadTrack::trackInOutStatistics()
{
	for (int i = 0; i < MAX_TRACK_COUNT; i++)
	{
		Tracker *pTracker = &m_pTrackers[i];
		if (!pTracker->bTrack || pTracker->nTrajCount <= 20 || pTracker->bStatOK)
		{
			continue;
		}

		if (!pTracker->bInitDirection)
		{
			int count0 = 0;
			for (int j = 0; j < 10; j++)
			{
				count0 += pointInLineSide(pTracker->centres[j], m_lineStart, m_lineEnd);
			}
			pTracker->nInitDirection = (count0 > 0) ? 1 : -1;

			// Lock the start side. The trajectory buffer discards its oldest points
			// once full, so recomputing later would use points that are no longer
			// the start of the track.
			pTracker->bInitDirection = true;
		}

		// Recent points that lie on the other side of the line.
		// NOTE(review): the newest point (index nTrajCount - 1) is excluded; the
		// threshold below was presumably tuned for these 9 samples.
		int count1 = 0;
		for (int j = pTracker->nTrajCount - 10; j < pTracker->nTrajCount - 1; j++)
		{
			const int flag = pointInLineSide(pTracker->centres[j], m_lineStart, m_lineEnd);
			if (flag != 0 && flag != pTracker->nInitDirection)
			{
				count1++;
			}
		}

		if (count1 > 6)
		{
			if (pTracker->nInitDirection == m_nArrowStart)
			{
				m_InCount++;
			}
			else
			{
				m_OutCount++;
			}
			pTracker->bStatOK = true;
		}
	}
}

四、效果展示:

[图1] 效果图1

 

[图2] 效果图2　　交流 QQ:187100248

 

算法 demo:封装了一个前期的演示程序,用于验证算法可以在 Windows 上运行。编译环境为 VS2015 + CUDA,运行需要 NVIDIA GTX 系列显卡:

demo网盘地址:链接:https://pan.baidu.com/s/1IKaGZOqDEO-McBnkwhunzg  提取码:c9ga 

部分原始数据,需要自己标注:链接:https://pan.baidu.com/s/1AOasuALCc7LX6cspyYUrcg 
提取码:rn13

 

 

你可能感兴趣的:(深度学习)