CV最简单的分类算法——knn(k nearest neighbors)

邻近算法,或者说K最近邻(kNN,k-Nearest Neighbor)分类算法可以说是整个数据挖掘分类技术中最简单的方法了。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。

(以上简介来自百科)


根据KNN的算法思路,简单实现了一下KNN,以下为自己实现的KNN与OpenCV内置KNN的分类结果比较。 (自己实现的版本没有进一步优化,比如在选出用于vote的k个最近邻时,可以用优先队列代替对全部样本排序,进行简单的优化)

knn.h


#ifndef __KNN_H__
#define __KNN_H__

#include "opencv\ml.h"
#include "opencv\highgui.h"
#include <queue>

using namespace std;
using namespace cv;

/// One candidate neighbour considered by KNN::classify:
/// how far the training row is from the query sample, and which class it has.
struct Key_Index
{
	double dist;  ///< squared distance from the query sample to this training row
	int label;    ///< class label of this training row
};

/// Minimal brute-force k-nearest-neighbour classifier.
///
/// train() only remembers the training matrix and labels (cv::Mat assignment
/// shares the caller's buffer rather than copying it); all distance work is
/// done per query in classify().
class KNN
{
public:
	/// Starts with max_k = 0 so the member is never read uninitialized,
	/// even if classify() is called before train().
	KNN() : max_k(0) {}

	/// Stores the training set.
	/// @param train_data_mat  one training sample per row (read as float).
	/// @param labels_mat      one integer class label per row (read as int from column 0).
	/// @param max_k           upper bound applied to the k passed to classify().
	void train(Mat train_data_mat, Mat labels_mat, int max_k);

	/// Returns the majority label among the k nearest training rows
	/// (k is capped at the max_k given to train()).
	int classify(Mat sample, int k);

private:
	Mat train_data;  ///< training samples, one per row
	Mat labels;      ///< class label for each training row
	int max_k;       ///< cap on the neighbour count used by classify()
};

#endif


knn.cpp

#include "knn.h"
#include <algorithm>
#include <map>
#include <vector>

/// Orders neighbour candidates by ascending distance, so sorting a range of
/// Key_Index puts the closest training rows first.
bool operator <(const Key_Index &lhs, const Key_Index &rhs)
{
	const bool lhs_is_closer = rhs.dist > lhs.dist;
	return lhs_is_closer;
}

/// Remembers the training set; kNN is lazy, so nothing is fitted here and all
/// of the work happens later in classify().
void KNN::train(Mat train_data_mat, Mat labels_mat, int max_k)
{
	// The parameter shadows the member of the same name, hence `this->`.
	this->max_k = max_k;
	labels = labels_mat;
	train_data = train_data_mat;
}

/// Classifies `sample` by majority vote among its k nearest training rows.
///
/// Distance is the squared Euclidean distance over the feature columns shared
/// by `sample` and the training matrix. Both matrices are read as float, and
/// labels are read as int from column 0 of the label matrix.
/// `k` is clamped to the `max_k` given to train() and to the number of
/// training rows, so the vote never reads past the candidate list.
///
/// On a tie in vote counts, the label that reached the winning count first
/// (i.e. the one with the nearer deciding neighbour) wins.
///
/// @param sample  query features; assumed to be a single row of floats laid out
///                like one training row (NOTE(review): not validated here).
/// @param k       requested number of neighbours.
/// @return the winning label, or -1 when there is nothing to vote on
///         (no training data, or k <= 0 after clamping).
int KNN::classify(Mat sample, int k)
{
	const int rows = train_data.rows;
	if (k > max_k) k = max_k;
	if (k > rows) k = rows;   // never look at more neighbours than exist
	if (k <= 0) return -1;    // previously returned an uninitialized label

	// Use every feature column, not just the first two; guard against a
	// sample that is shorter than a training row.
	const int dims = std::min(train_data.cols, static_cast<int>(sample.total()));
	const float *query = sample.ptr<float>(0);

	// std::vector owns the buffer, replacing the mismatched new[] / delete.
	std::vector<Key_Index> candidates(rows);
	for (int i = 0; i < rows; ++i)
	{
		const float *row = train_data.ptr<float>(i);
		double dist = 0.0;
		for (int d = 0; d < dims; ++d)
		{
			const float diff = row[d] - query[d];
			dist += diff * diff;
		}
		candidates[i].dist = dist;
		candidates[i].label = labels.at<int>(i, 0);
	}

	// Only the k closest rows take part in the vote, so a partial sort suffices.
	std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end());

	std::map<int, int> votes;
	int best_label = candidates[0].label;
	int best_count = 0;
	for (int i = 0; i < k; ++i)
	{
		const int count = ++votes[candidates[i].label];
		// Strict '>' keeps the earlier (nearer) label on ties.
		if (count > best_count)
		{
			best_count = count;
			best_label = candidates[i].label;
		}
	}
	return best_label;
}


main.cpp

#include "opencv\ml.h"
#include "opencv\highgui.h"
#include "knn.h"
#include <time.h>

using namespace cv;

int main()
{
	
	int labels[10] = { 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 };
	Mat labels_mat(10, 1, CV_32SC1, labels);//a 10*1 matrix initializing by labels
	float training_data[10][2];//in fact ten points
	srand(time(0));
	
	for (int i = 0; i<5; i++)
	{
		training_data[i][0] = rand() % 255 + 1;
		training_data[i][1] = rand() % 255 + 1;
		training_data[i + 5][0] = rand() % 255 + 255;
		training_data[i + 5][1] = rand() % 255 + 255;
	}
	Mat training_data_mat(10, 2, CV_32FC1, training_data);

	int width = 512, height = 512;
	Mat imagea = Mat::zeros(height, width, CV_8UC3);
	Mat imageb = Mat::zeros(height, width, CV_8UC3);
	Vec3b red(0, 0, 255), green(0, 255, 0);

	//self define knn
	KNN knn;
	knn.train(training_data_mat, labels_mat, 5);

	for (int i = 0; i < imagea.rows; ++i)
	{
		for (int j = 0; j < imagea.cols; ++j)
		{
			const Mat sampleMat = (Mat_<float>(1, 2) << i, j);
			int result = knn.classify(sampleMat, 3);
			if (result != 0)
			{
				//image.at<Vec3b>(x1,x2) visit the point(x2, x1) of the image
				//x2 is the x and x1 is the y
				imagea.at<Vec3b>(j, i) = red;
			}
			else
			{
				imagea.at<Vec3b>(j, i) = green;
			}
		}
	}
	
	//call opencv's knn
	CvKNearest cvknn;
	cvknn.train(training_data_mat, labels_mat, Mat(), false, 5);

	for (int i = 0; i < imageb.rows; ++i)
	{
		for (int j = 0; j < imageb.cols; ++j)
		{
			const Mat sampleMat = (Mat_<float>(1, 2) << i, j);
			float result = cvknn.find_nearest(sampleMat, 3);
			if (result != 0)
			{
				//image.at<Vec3b>(x1,x2) visit the point(x2, x1) of the image
				//x2 is the x and x1 is the y
				imageb.at<Vec3b>(j, i) = red;
			}
			else
			{
				imageb.at<Vec3b>(j, i) = green;
			}
		}
	}
	
	// Mark the training data using solid circles and rectangles
	for (int i = 0; i<5; ++i)
	{
		circle(imagea, Point(training_data[i][0], training_data[i][1]),
			5, Scalar(255, 255, 255), -1, 8);
		rectangle(imagea, Point(training_data[i + 5][0], training_data[i + 5][1]),
			Point(training_data[i + 5][0] + 5, training_data[i + 5][1] + 5), 
			Scalar(0, 0, 0), -1, 8);

		circle(imageb, Point(training_data[i][0], training_data[i][1]),
			5, Scalar(255, 255, 255), -1, 8);
		rectangle(imageb, Point(training_data[i + 5][0], training_data[i + 5][1]),
			Point(training_data[i + 5][0] + 5, training_data[i + 5][1] + 5),
			Scalar(0, 0, 0), -1, 8);
	}

	imshow("KNN-my", imagea); // show the result of my knn
	imshow("KNN-cv", imageb); // show the result of the onpencv's knn
	cvWaitKey(0);

	return 0;
}

CV最简单的分类算法——knn(k nearest neighbors)_第1张图片


CV最简单的分类算法——knn(k nearest neighbors)_第2张图片

你可能感兴趣的:(C++,opencv,cv)