STC跟踪算法示例代码

//STCTracker.h

#pragma once
#include <algorithm>
#include <opencv2/opencv.hpp>
using namespace cv;
using namespace std;

/**
 * Spatio-temporal context (STC) tracker.
 *
 * Usage: call init() once with the first frame and the target box, then call
 * tracking() for every following frame.
 */
class STCTracker
{
public:
	STCTracker();
	~STCTracker();

	// Builds all models from the first frame and the initial target box.
	void init(const Mat frame, const Rect box);
	// Locates the target in a new frame and moves trackBox onto it.
	void tracking(const Mat frame, Rect &trackBox);

private:
	// Fills hammingWin with a 2-D Hamming window.
	void createHammingWin();
	// Element-wise complex multiply (flag == 0) or divide (flag != 0) of two-channel matrices.
	void complexOperation(const Mat src1, const Mat src2, Mat &dst, int flag = 0);
	// Recomputes cxtPriorPro and cxtPosteriorPro for the given context patch.
	void getCxtPriorPosteriorModel(const Mat image);
	// Learns STModel from the patch and blends it into STCModel.
	void learnSTCModel(const Mat image);

	double sigma;			// scale of the Gaussian weighting used in the context prior
	double alpha;			// scale of the posterior (confidence) profile
	double beta;			// shape exponent of the posterior profile
	double rho;				// blending weight used when STCModel is updated
	Point center;			// current object position (frame coordinates)
	Rect cxtRegion;			// context region around the object (frame coordinates)

	Mat cxtPriorPro;		// context prior probability
	Mat cxtPosteriorPro;	// posterior probability (confidence profile)
	Mat STModel;			// spatial context model learned from the latest frame
	Mat STCModel;			// spatio-temporal context model accumulated over frames
	Mat hammingWin;			// Hamming window applied to every context patch
};

//STCTracker.cpp

#include "STCTracker.h"

/**
 * Creates an idle tracker.
 *
 * The scalar hyper-parameters are given determinate values here so that a
 * default-constructed object never holds indeterminate doubles; alpha, beta
 * and rho use the same defaults that init() assigns, and sigma stays 0 until
 * init() derives it from the target box. The matrices stay empty until init().
 */
STCTracker::STCTracker()
	: sigma(0.0), alpha(2.25), beta(1.0), rho(0.075)
{
}

// Nothing to release by hand: all members are plain values or cv::Mat objects,
// whose own destructors take care of their storage.
STCTracker::~STCTracker() {}

/************ Create a Hamming window ********************/
void STCTracker::createHammingWin()
{
	for (int i = 0; i < hammingWin.rows; i++)
	{
		for (int j = 0; j < hammingWin.cols; j++)
		{
			hammingWin.at(i, j) = (0.54 - 0.46 * cos( 2 * CV_PI * i / hammingWin.rows )) 
				* (0.54 - 0.46 * cos( 2 * CV_PI * j / hammingWin.cols ));
		}
	}
}

/************ Define two complex-value operation *****************/
void STCTracker::complexOperation(const Mat src1, const Mat src2, Mat &dst, int flag)
{
	CV_Assert(src1.size == src2.size);
	CV_Assert(src1.channels() == 2);

	Mat A_Real, A_Imag, B_Real, B_Imag, R_Real, R_Imag;
	vector planes;
	split(src1, planes);
	planes[0].copyTo(A_Real);
	planes[1].copyTo(A_Imag);

	split(src2, planes);
	planes[0].copyTo(B_Real);
	planes[1].copyTo(B_Imag);

	dst.create(src1.rows, src1.cols, CV_64FC2);
	split(dst, planes);
	R_Real = planes[0];
	R_Imag = planes[1];

	for (int i = 0; i < A_Real.rows; i++)
	{
		for (int j = 0; j < A_Real.cols; j++)
		{
			double a = A_Real.at(i, j);
			double b = A_Imag.at(i, j);
			double c = B_Real.at(i, j);
			double d = B_Imag.at(i, j);

			if (flag)
			{
				// division: (a+bj) / (c+dj)
				R_Real.at(i, j) = (a * c + b * d) / (c * c + d * d + 0.000001);
				R_Imag.at(i, j) = (b * c - a * d) / (c * c + d * d + 0.000001);
			}
			else
			{
				// multiplication: (a+bj) * (c+dj)
				R_Real.at(i, j) = a * c - b * d;
				R_Imag.at(i, j) = b * c + a * d;
			}
		}
	}
	merge(planes, dst);
}

/************ Get context prior and posterior probability ***********/
void STCTracker::getCxtPriorPosteriorModel(const Mat image)
{
	CV_Assert(image.size == cxtPriorPro.size);

	double sum_prior(0), sum_post(0);
	for (int i = 0; i < cxtRegion.height; i++)
	{
		for (int j = 0; j < cxtRegion.width; j++)
		{
			double x = j + cxtRegion.x;
			double y = i + cxtRegion.y;
			double dist = sqrt((center.x - x) * (center.x - x) + (center.y - y) * (center.y - y));

			// equation (5) in the paper
			cxtPriorPro.at(i, j) = exp(- dist * dist / (2 * sigma * sigma));
			sum_prior += cxtPriorPro.at(i, j);

			// equation (6) in the paper
			cxtPosteriorPro.at(i, j) = exp(- pow(dist / sqrt(alpha), beta));
			sum_post += cxtPosteriorPro.at(i, j);
		}
	}
	cxtPriorPro.convertTo(cxtPriorPro, -1, 1.0/sum_prior);
	cxtPriorPro = cxtPriorPro.mul(image);
	cxtPosteriorPro.convertTo(cxtPosteriorPro, -1, 1.0/sum_post);
}

/************ Learn Spatio-Temporal Context Model ***********/
void STCTracker::learnSTCModel(const Mat image)
{
	// step 1: Get context prior and posterior probability
	getCxtPriorPosteriorModel(image);

	// step 2-1: Execute 2D DFT for prior probability
	Mat priorFourier;
	Mat planes1[] = {cxtPriorPro, Mat::zeros(cxtPriorPro.size(), CV_64F)};
	merge(planes1, 2, priorFourier);
	dft(priorFourier, priorFourier);

	// step 2-2: Execute 2D DFT for posterior probability
	Mat postFourier;
	Mat planes2[] = {cxtPosteriorPro, Mat::zeros(cxtPosteriorPro.size(), CV_64F)};
	merge(planes2, 2, postFourier);
	dft(postFourier, postFourier);

	// step 3: Calculate the division
	Mat conditionalFourier;
	complexOperation(postFourier, priorFourier, conditionalFourier, 1);

	// step 4: Execute 2D inverse DFT for conditional probability and we obtain STModel
	dft(conditionalFourier, STModel, DFT_INVERSE | DFT_REAL_OUTPUT | DFT_SCALE);

	// step 5: Use the learned spatial context model to update spatio-temporal context model
	addWeighted(STCModel, 1.0 - rho, STModel, rho, 0.0, STCModel);
}

/************ Initialize the hyper parameters and models ***********/
void STCTracker::init(const Mat frame, const Rect box)
{
	// initial some parameters
	alpha = 2.25;
	beta = 1;
	rho = 0.075;
	sigma = 0.5 * (box.width + box.height);

	// the object position
	center.x = box.x + 0.5 * box.width;
	center.y = box.y + 0.5 * box.height;

	// the context region
	cxtRegion.width = 2 * box.width;
	cxtRegion.height = 2 * box.height;
	cxtRegion.x = center.x - cxtRegion.width * 0.5;
	cxtRegion.y = center.y - cxtRegion.height * 0.5;
	cxtRegion &= Rect(0, 0, frame.cols, frame.rows);

	// the prior, posterior and conditional probability and spatio-temporal context model
	cxtPriorPro = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
	cxtPosteriorPro = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
	STModel = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
	STCModel = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);

	// create a Hamming window
	hammingWin = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
	createHammingWin();

	Mat gray;
	cvtColor(frame, gray, CV_RGB2GRAY);

	// normalized by subtracting the average intensity of that region
	Scalar average = mean(gray(cxtRegion));
	Mat context;
	gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);

	// multiplies a Hamming window to reduce the frequency effect of image boundary
	context = context.mul(hammingWin);

	// learn Spatio-Temporal context model from first frame
	learnSTCModel(context);
}

/******** STCTracker: calculate the confidence map and find the max position *******/
void STCTracker::tracking(const Mat frame, Rect &trackBox)
{
	Mat gray;
	cvtColor(frame, gray, CV_RGB2GRAY);

	// normalized by subtracting the average intensity of that region
	Scalar average = mean(gray(cxtRegion));
	Mat context;
	gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);

	// multiplies a Hamming window to reduce the frequency effect of image boundary
	context = context.mul(hammingWin);

	// step 1: Get context prior probability
	getCxtPriorPosteriorModel(context);

	// step 2-1: Execute 2D DFT for prior probability
	Mat priorFourier;
	Mat planes1[] = {cxtPriorPro, Mat::zeros(cxtPriorPro.size(), CV_64F)};
	merge(planes1, 2, priorFourier);
	dft(priorFourier, priorFourier);

	// step 2-2: Execute 2D DFT for conditional probability
	Mat STCModelFourier;
	Mat planes2[] = {STCModel, Mat::zeros(STCModel.size(), CV_64F)};
	merge(planes2, 2, STCModelFourier);
	dft(STCModelFourier, STCModelFourier);

	// step 3: Calculate the multiplication
	Mat postFourier;
	complexOperation(STCModelFourier, priorFourier, postFourier, 0);

	// step 4: Execute 2D inverse DFT for posterior probability namely confidence map
	Mat confidenceMap;
	dft(postFourier, confidenceMap, DFT_INVERSE | DFT_REAL_OUTPUT| DFT_SCALE);

	// step 5: Find the max position
	Point point;
	minMaxLoc(confidenceMap, 0, 0, 0, &point);

	// step 6-1: update center, trackBox and context region
	center.x = cxtRegion.x + point.x;
	center.y = cxtRegion.y + point.y;
	trackBox.x = center.x - 0.5 * trackBox.width;
	trackBox.y = center.y - 0.5 * trackBox.height;
	trackBox &= Rect(0, 0, frame.cols, frame.rows);

	cxtRegion.x = center.x - cxtRegion.width * 0.5;
	cxtRegion.y = center.y - cxtRegion.height * 0.5;
	cxtRegion &= Rect(0, 0, frame.cols, frame.rows);

	// step 7: learn Spatio-Temporal context model from this frame for tracking next frame
	average = mean(gray(cxtRegion));
	gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);
	context = context.mul(hammingWin);
	learnSTCModel(context);
}

//runSTCTracker.cpp

#include "STCTracker.h"

Rect box;					// rectangle being selected with the mouse
bool drawing_box = false;	// true while the left button is held down
bool gotBB = false;			// set once the left button has been released

// Mouse callback for the selection window: a left-button drag defines `box`.
// On release the rectangle is normalised so that width and height are not
// negative, and gotBB is raised. `flags` and `param` are not used.
static void mouseHandler(int event, int x, int y, int flags, void *param){
	if (event == CV_EVENT_LBUTTONDOWN){
		// Start a new selection anchored at the press position.
		box = Rect( x, y, 0, 0 );
		drawing_box = true;
	}
	else if (event == CV_EVENT_MOUSEMOVE){
		if (!drawing_box)
			return;
		// Stretch the rectangle to the cursor; the extents may be negative for now.
		box.width = x - box.x;
		box.height = y - box.y;
	}
	else if (event == CV_EVENT_LBUTTONUP){
		drawing_box = false;
		// A drag towards the left or upwards leaves a negative extent: flip it.
		if (box.width < 0){
			box.x += box.width;
			box.width = -box.width;
		}
		if (box.height < 0){
			box.y += box.height;
			box.height = -box.height;
		}
		gotBB = true;
	}
}

int main(int argc, char * argv[])
{	
	const char *vedioFileName = "MAH00054.MP4";
	VideoCapture capture;
	//capture.open("optical_flow_input.avi");
	capture.open(vedioFileName);
	bool fromfile = true;

	if (!capture.isOpened())
	{
		cout << "capture device failed to open!" << endl;
		return -1;
	}
	//Register mouse callback to draw the bounding box
	cvNamedWindow("Tracker", CV_WINDOW_AUTOSIZE);
	cvSetMouseCallback("Tracker", mouseHandler, NULL ); 

	Mat frame;
	capture >> frame;
	while(!gotBB)
	{
		if (!fromfile)
			capture >> frame;

		imshow("Tracker", frame);
		if (cvWaitKey(20) == 27)
			return 1;
	}
	//Remove callback
	cvSetMouseCallback("Tracker", NULL, NULL ); 

	STCTracker stcTracker;
	stcTracker.init(frame, box);

	int frameCount = 0;
	while (1)
	{
		capture >> frame;
		if (frame.empty())
			return -1;
		double t = (double)cvGetTickCount();
		frameCount++;

		// tracking
		stcTracker.tracking(frame, box);	

		// show the result
		stringstream buf;
		buf << frameCount;
		string num = buf.str();
		putText(frame, num, Point(20, 30), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 0, 255), 3);
		rectangle(frame, box, Scalar(0, 0, 255), 3);
		imshow("Tracker", frame);

		t = (double)cvGetTickCount() - t;
		cout << "cost time: " << t / ((double)cvGetTickFrequency()*1000.) << endl;

		if ( cvWaitKey(1) == 27 )
			break;
	}

	return 0;
}



转载:http://blog.csdn.net/zouxy09

你可能感兴趣的:(算法,c/c++)