meanshift聚类的实现

参见：http://blog.csdn.net/u014568921/article/details/45197027

// meanshift-cluster.cpp : Defines the entry point for the console application.
//

#include "stdafx.h"
#include<iostream>
#include<vector>
#include<assert.h>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<time.h>
using namespace std;

// Scalar type used for every coordinate of a data point.
typedef double MSTYPE;

/**
 * Mean-shift clustering over a point set held inside the object.
 *
 * Usage: construct with a kernel bandwidth, fill the data set with
 * generatedata(), then call apply() to move every point to the mode it
 * converges to.
 *
 * The kernel weight of a sample at distance d is exp(-d*d / kernel_bandwidth);
 * note that the bandwidth is used directly as the denominator.
 */
class meanshift
{
private:
	// One point of the data set: one coordinate per dimension.
	struct MSData
	{
		std::vector<MSTYPE> data;

		// Creates a point with d coordinates, all initialized to 0.
		explicit MSData(std::size_t d) : data(d) {}
	};

	std::vector<MSData> dataset;   // original (never shifted) points
	double kernel_bandwidth;       // denominator of the Gaussian exponent

	// Convergence threshold: a point stops moving once one shift step
	// moves it by no more than this distance.
	static double convergence_epsilon() { return 0.00000001; }

	/**
	 * Computes one mean-shift step for vec.
	 *
	 * @param vec current position of the point being shifted.
	 * @return the kernel-weighted mean of all points in dataset, or vec
	 *         itself when the weights do not sum to a positive number.
	 */
	MSData shiftvec(const MSData &vec) const
	{
		const std::size_t dim = vec.data.size();
		MSData shiftvector(dim);

		double total_weight = 0;
		for (std::size_t i = 0; i < dataset.size(); ++i) {
			// Bind by reference: copying every sample on every step is wasted work.
			const MSData &sample = dataset[i];
			const double weight = gaussian_kernel(euclidean_distance(vec, sample));
			for (std::size_t j = 0; j < dim; ++j) {
				shiftvector.data[j] += sample.data[j] * weight;
			}
			total_weight += weight;
		}

		// A zero or NaN weight sum (e.g. zero bandwidth, or every weight
		// underflowing) would turn the division below into NaN coordinates.
		// Leave the point where it is instead.
		if (!(total_weight > 0)) {
			return vec;
		}

		for (std::size_t j = 0; j < dim; ++j) {
			shiftvector.data[j] /= total_weight;
		}
		return shiftvector;
	}

	// Gaussian kernel weight for a sample at the given distance.
	double gaussian_kernel(double distance) const
	{
		return std::exp(-(distance * distance) / kernel_bandwidth);
	}

	// Euclidean distance between two points; both must have the same
	// number of coordinates (checked with assert).
	static double euclidean_distance(const MSData &data1, const MSData &data2)
	{
		assert(data1.data.size() == data2.data.size());
		double sum = 0;
		for (std::size_t i = 0; i < data1.data.size(); ++i) {
			const double diff = data1.data[i] - data2.data[i];
			sum += diff * diff;
		}
		return std::sqrt(sum);
	}

	// Writes a point to std::cout as "(c0,c1,...)", whatever its dimension.
	static void print_point(const MSData &point)
	{
		std::cout << "(";
		for (std::size_t k = 0; k < point.data.size(); ++k) {
			if (k != 0) {
				std::cout << ",";
			}
			std::cout << point.data[k];
		}
		std::cout << ")";
	}

public:
	/**
	 * @param kernel_bandwidth denominator of the Gaussian kernel exponent.
	 *
	 * Also seeds the C random number generator from the current time, so
	 * generatedata() produces different points on each run.
	 */
	meanshift(double kernel_bandwidth) : kernel_bandwidth(kernel_bandwidth)
	{
		std::srand(static_cast<unsigned int>(::time(0)));
	}

	/**
	 * Runs mean shift until no point moves by more than the convergence
	 * threshold in a single step.
	 *
	 * Every step of every point is computed against the original data set.
	 * The largest shift of each pass is printed to stdout, followed by one
	 * line per point giving its original and final coordinates.
	 *
	 * @return the converged position of each point, in data set order.
	 */
	std::vector<MSData> apply()
	{
		const double epsilon = convergence_epsilon();

		// stop_moving[i] != 0 once point i has converged.
		std::vector<int> stop_moving(dataset.size(), 0);
		std::vector<MSData> shifted_points = dataset;

		double max_shift_distance;
		do {
			max_shift_distance = 0;
			for (std::size_t i = 0; i < shifted_points.size(); ++i) {
				if (stop_moving[i]) {
					continue;
				}
				const MSData point_new = shiftvec(shifted_points[i]);
				const double shift_distance = euclidean_distance(point_new, shifted_points[i]);
				if (shift_distance > max_shift_distance) {
					max_shift_distance = shift_distance;
				}
				if (shift_distance <= epsilon) {
					stop_moving[i] = 1;
				}
				shifted_points[i] = point_new;
			}
			std::printf("max_shift_distance: %f\n", max_shift_distance);
		} while (max_shift_distance > epsilon);

		// Print every coordinate rather than assuming exactly two dimensions.
		for (std::size_t i = 0; i < dataset.size(); ++i) {
			std::cout << "原始坐标 ";
			print_point(dataset[i]);
			std::cout << "   滑动到  ";
			print_point(shifted_points[i]);
			std::cout << '\n';
		}

		return shifted_points;
	}

	/**
	 * Appends datanums random points to the data set.
	 *
	 * @param datanums number of points to add; values <= 0 add nothing.
	 * @param span     one entry per dimension; coordinate j of each point
	 *                 is rand() / (RAND_MAX + 1.0) * span[j].
	 */
	void generatedata(int datanums, const std::vector<int> &span)
	{
		if (datanums > 0) {
			dataset.reserve(dataset.size() + static_cast<std::size_t>(datanums));
		}
		for (int i = 0; i < datanums; ++i) {
			MSData dd(span.size());
			for (std::size_t j = 0; j < span.size(); ++j) {
				dd.data[j] = double(std::rand()) / (RAND_MAX + 1.0) * span[j];
			}
			dataset.push_back(dd);
		}
	}
};


int _tmain(int argc, _TCHAR* argv[])
{
	meanshift ms(4);
	vector<int>span;
	span.push_back(20);
	span.push_back(20);
	ms.generatedata(100, span);
	ms.apply();



	return 0;
}


结果如下图

meanshift聚类的实现_第1张图片

你可能感兴趣的:(聚类,MeanShift)