DBSCAN聚类算法的实现

参考wiki

https://en.wikipedia.org/wiki/DBSCAN

DBSCAN(Density-Based Spatial Clustering of Applications with Noise)是一个比较有代表性的基于密度的聚类算法。与划分和层次聚类方法不同,它将簇定义为密度相连的点的最大集合,能够把具有足够高密度的区域划分为簇,并可在噪声的空间数据库中发现任意形状的聚类。

DBSCAN中的的几个定义:
Ε邻域:给定对象半径为Ε内的区域称为该对象的Ε邻域;
核心对象:如果给定对象Ε领域内的样本点数大于等于MinPts,则称该对象为核心对象;
直接密度可达:对于样本集合D,如果样本点q在p的Ε领域内,并且p为核心对象,那么对象q从对象p直接密度可达。
密度可达:对于样本集合D,给定一串样本点p 1,p 2….p n,p= p 1,q= p n,假如对象p i从p i-1直接密度可达,那么对象q从对象p密度可达。
密度相连:存在样本集合D中的一点o,如果对象o到对象p和对象q都是密度可达的,那么p和q密度相联。
可以发现,密度可达是直接密度可达的传递闭包,并且这种关系是非对称的。密度相连是对称关系。DBSCAN目的是找到密度相连对象的最大集合。
Eg: 假设半径Ε=3,MinPts=3,点p的E领域中有点{m,p,p1,p2,o}, 点m的E领域中有点{m,q,p,m1,m2},点q的E领域中有点{q,m},点o的E领域中有点{o,p,s},点s的E领域中有点{o,s,s1}.
那么核心对象有p,m,o,s(q不是核心对象,因为它对应的E领域中点数量等于2,小于MinPts=3);
点m从点p直接密度可达,因为m在p的E领域内,并且p为核心对象;
点q从点p密度可达,因为点q从点m直接密度可达,并且点m从点p直接密度可达;
点q到点s密度相连,因为点q从点p密度可达,并且s从点p密度可达。

DBSCAN目的是找到密度相连对象的最大集合。

而正是因为密度相连的引入,dbscan可以发现任意形状的聚类。

这个算法概念很多,但理解之后实现起来很简单,大家可以自己试试。

下面是我的实现

// DBSCAN.cpp : 定义控制台应用程序的入口点。
//

#include "stdafx.h"
#include
#include
#include 
#include 
#include
#include
#include
using namespace std;


#define DBPOINTTYPE int

class dbscan
{
public:
	enum DBState{ unlabeled, core, border, noise };//数据点当前的状态
	struct DBPoint
	{
		int dim;//数据点的维数
		DBState state;
		vectordata;
		vectornear;//该点直接密度可达的点的集合,以dataset索引方式存储
		DBPoint(unsigned int d) :dim(d)
		{
			state = unlabeled;
		}
	};
	struct DBCluster
	{
		int label;//该簇的标号
		setborder;//该簇边界点的集合,以dataset索引方式存储
		setcore;//该簇核心点的集合,以dataset索引方式存储
	};
public:
	dbscan(int minpts, double rad) :MinPts(minpts), radius(rad)
	{
		time_t t;
		srand(time(&t));
	}
	~dbscan();
	void apply();
	void generate_dataset(int datasetsize, int Y, int X);
	void showresult(){
		for (int i = 0; i < clusters.size(); i++)
		{
			char string[100];
			sprintf(string, "第个%d簇:", i);
			cout << string << endl;
			cout << "核心:" << endl;
			for (set::iterator it = clusters[i]->core.begin(); it != clusters[i]->core.end(); it++)
			{
				sprintf(string, "编号%d   ", *it);
				cout << string << "(" << dataset[*(it)].data[0] << "," << dataset[*(it)].data[1] << ")" << endl;
			}
			cout << "边界:" << endl;
			for (set::iterator it = clusters[i]->border.begin(); it != clusters[i]->border.end(); it++)
			{
				sprintf(string, "编号%d   ", *it);
				cout << string << *it << "(" << dataset[*(it)].data[0] << "," << dataset[*(it)].data[1] << ")" << endl;
			}
			cout << endl << endl;
		}
	}

private:
	int MinPts;
	double radius;

	vectordataset;
	vectorclusters;

private:
	//bool isdensityconnected(DBCluster&cluster1, DBCluster&cluster2);//密度相连
	bool iscore(DBPoint&point);
	bool iscore(int k);
	double distance(DBPoint&point1, DBPoint&point2)
	{
		double dis = 0;
		for (int i = 0; i < point1.dim; i++)
			dis += pow(point1.data[i] - point2.data[i], 2.0);
		return sqrt(dis);
	}
	void expand(int i, DBCluster*clus);
	void mergecluster(int k1, int k2);//合并密度相连的两个簇
};

void dbscan::mergecluster(int k1, int k2)
{

	clusters[k1]->core.insert(clusters[k2]->core.begin(), clusters[k2]->core.end());
	clusters[k1]->border.insert(clusters[k2]->border.begin(), clusters[k2]->border.end());
	for (int i = k2 + 1; i < clusters.size(); i++)
		clusters[i]->label--;
}

void dbscan::generate_dataset(int datasetsize, int Y, int X)
{

	for (int i = 0; i < datasetsize; i++)
	{
		DBPoint point(2);
		point.data.resize(2);
		point.data[0] = X*double(rand()) / double(RAND_MAX + 1.0);
		point.data[1] = Y*double(rand()) / double(RAND_MAX + 1.0);
		dataset.push_back(point);
	}
}



void dbscan::apply()
{
	int k = 0;

	while (k < dataset.size() - 1)
	{
		for (int i = k + 1; i < dataset.size(); i++)
		{
			double dis = distance(dataset[k], dataset[i]);
			if (dis < radius)
			{
				dataset[k].near.push_back(i);
				dataset[i].near.push_back(k);
			}
		}
		k++;
	}
	for (int i = 0; i < dataset.size(); i++)
	{
		if (iscore(dataset[i]) && dataset[i].state == unlabeled)
		{
			DBCluster*clus = new DBCluster;
			clus->label = clusters.size();
			expand(i, clus);
			clusters.push_back(clus);
		}
	}
	for (int i = 0; i < dataset.size(); i++)
	{
		if (dataset[i].state == unlabeled)
			dataset[i].state = noise;
	}
	/*k = 0;
	while (k < clusters.size() - 1)//合并密度相连的集合
	{
	for (int i = k + 1; i < clusters.size(); i++)
	{
	if (!clusters[i]->core.empty() && !clusters[k]->core.empty())
	{
	setaa;
	//insert_iterator > >res_ins(aa, aa.begin());
	set_union(clusters[i]->border.begin(), clusters[i]->border.end(),//求并集
	clusters[k]->border.begin(), clusters[k]->border.end(), inserter(aa, aa.begin()));
	if (aa.size() != clusters[i]->border.size() + clusters[k]->border.size())//密度相连
	mergecluster(k, i);
	}
	}
	}
	for (int i = 0; i < clusters.size(); i++)
	{
	if (clusters[i]->core.empty())
	{
	delete clusters[i];
	clusters.erase(clusters.begin() + i, clusters.begin() + i + 1);
	}
	}*/
	//下面验证一下结果
	int sum = 0;
	for (int i = 0; i < clusters.size(); i++)
	{
		sum += clusters[i]->border.size() + clusters[i]->core.size();
	}
	for (int i = 0; i < dataset.size(); i++)
	{
		if (dataset[i].state == noise)
			sum++;
	}
	_ASSERTE(sum == dataset.size());

}

bool dbscan::iscore(DBPoint&point)
{
	return point.near.size() >= MinPts;
}
bool dbscan::iscore(int k)
{
	return dataset[k].near.size() >= MinPts;
}

/*void dbscan::expand(int k, DBCluster*clus)
{
if (clus->core.find(k) == clus->core.end())
{
clus->core.insert(k);//直接密度可达或者密度可达
dataset[k].state = core;
for (int i = 0; i < dataset[k].near.size(); i++)
{
if (!iscore(dataset[k].near[i]) && clus->border.find(i) == clus->border.end())
{
clus->border.insert(dataset[k].near[i]);//直接密度可达
dataset[i].state = border;
}
if (iscore(dataset[k].near[i]))
expand(dataset[k].near[i], clus);
}
}
}*/


void dbscan::expand(int k, DBCluster*clus)
{
	vectoraa, bb;
	aa.push_back(k);
	while (!aa.empty())
	{
		int gg = aa.back();
		aa.pop_back();
		if (iscore(dataset[gg]) && clus->core.find(gg) == clus->core.end())
		{
			clus->core.insert(gg);
			dataset[gg].state = core;
			for (int i = 0; i < dataset[gg].near.size(); i++)//只有核心对象可以扩展
			{
				if (dataset[dataset[gg].near[i]].state == unlabeled)
					aa.insert(aa.end(), dataset[gg].near[i]);
			}
		}
		else if (!iscore(dataset[gg]) && clus->border.find(gg) == clus->border.end())
		{
			clus->border.insert(gg);
			dataset[gg].state = border;
		}
	}
}

dbscan::~dbscan()
{
	for (int i = 0; i < clusters.size(); i++)
		delete clusters[i];
}




int _tmain(int argc, _TCHAR* argv[])
{

	/*setaa;
	aa.insert(10);
	aa.insert(25);
	aa.insert(35); aa.insert(5);

	setbb;
	bb.insert(10);
	bb.insert(25);
	bb.insert(35); bb.insert(99);
	setcc;
	set_union(aa.begin(), aa.end(), bb.begin(), bb.end(), inserter(cc, cc.begin()));

	vectordd;
	dd.push_back(24);
	dd.insert(dd.begin(), aa.begin(), aa.end());

	setee;
	ee.insert(10);
	ee.insert(250);
	aa.insert(ee.begin(), ee.end());

	set::iterator it;
	for (it = aa.begin(); it != aa.end(); it++)
	cout << *it << endl;*/
	dbscan db(4, 10);
	db.generate_dataset(100, 100, 100);
	db.apply();
	db.showresult();


	system("pause");
	return 0;
}


一个python版

  1. # scoding=utf-8  
  2. import pylab as pl  
  3. from collections import defaultdict,Counter  
  4.   
  5. points = [[int(eachpoint.split("#")[0]), int(eachpoint.split("#")[1])] for eachpoint in open("points","r")]  
  6.   
  7. # 计算每个数据点相邻的数据点,邻域定义为以该点为中心以边长为2*EPs的网格  
  8. Eps = 10  
  9. surroundPoints = defaultdict(list)  
  10. for idx1,point1 in enumerate(points):  
  11.     for idx2,point2 in enumerate(points):  
  12.         if (idx1 < idx2):  
  13.             if(abs(point1[0]-point2[0])<=Eps and abs(point1[1]-point2[1])<=Eps):  
  14.                 surroundPoints[idx1].append(idx2)  
  15.                 surroundPoints[idx2].append(idx1)  
  16.   
  17. # 定义邻域内相邻的数据点的个数大于4的为核心点  
  18. MinPts = 5  
  19. corePointIdx = [pointIdx for pointIdx,surPointIdxs in surroundPoints.iteritems() if len(surPointIdxs)>=MinPts]  
  20.   
  21. # 邻域内包含某个核心点的非核心点,定义为边界点  
  22. borderPointIdx = []  
  23. for pointIdx,surPointIdxs in surroundPoints.iteritems():  
  24.     if (pointIdx not in corePointIdx):  
  25.         for onesurPointIdx in surPointIdxs:  
  26.             if onesurPointIdx in corePointIdx:  
  27.                 borderPointIdx.append(pointIdx)  
  28.                 break  
  29.   
  30. # 噪音点既不是边界点也不是核心点  
  31. noisePointIdx = [pointIdx for pointIdx in range(len(points)) if pointIdx not in corePointIdx and pointIdx not in borderPointIdx]  
  32.   
  33. corePoint = [points[pointIdx] for pointIdx in corePointIdx]   
  34. borderPoint = [points[pointIdx] for pointIdx in borderPointIdx]  
  35. noisePoint = [points[pointIdx] for pointIdx in noisePointIdx]  
  36.   
  37. # pl.plot([eachpoint[0] for eachpoint in corePoint], [eachpoint[1] for eachpoint in corePoint], 'or')  
  38. # pl.plot([eachpoint[0] for eachpoint in borderPoint], [eachpoint[1] for eachpoint in borderPoint], 'oy')  
  39. # pl.plot([eachpoint[0] for eachpoint in noisePoint], [eachpoint[1] for eachpoint in noisePoint], 'ok')  
  40.   
  41. groups = [idx for idx in range(len(points))]  
  42.   
  43. # 各个核心点与其邻域内的所有核心点放在同一个簇中  
  44. for pointidx,surroundIdxs in surroundPoints.iteritems():  
  45.     for oneSurroundIdx in surroundIdxs:  
  46.         if (pointidx in corePointIdx and oneSurroundIdx in corePointIdx and pointidx < oneSurroundIdx):  
  47.             for idx in range(len(groups)):  
  48.                 if groups[idx] == groups[oneSurroundIdx]:  
  49.                     groups[idx] = groups[pointidx]  
  50.   
  51. # 边界点跟其邻域内的某个核心点放在同一个簇中  
  52. for pointidx,surroundIdxs in surroundPoints.iteritems():  
  53.     for oneSurroundIdx in surroundIdxs:  
  54.         if (pointidx in borderPointIdx and oneSurroundIdx in corePointIdx):  
  55.             groups[pointidx] = groups[oneSurroundIdx]  
  56.             break  
  57.   
  58. # 取簇规模最大的5个簇  
  59. wantGroupNum = 3  
  60. finalGroup = Counter(groups).most_common(3)  
  61. finalGroup = [onecount[0for onecount in finalGroup]  
  62.   
  63. group1 = [points[idx] for idx in xrange(len(points)) if groups[idx]==finalGroup[0]]  
  64. group2 = [points[idx] for idx in xrange(len(points)) if groups[idx]==finalGroup[1]]  
  65. group3 = [points[idx] for idx in xrange(len(points)) if groups[idx]==finalGroup[2]]  
  66.   
  67. pl.plot([eachpoint[0for eachpoint in group1], [eachpoint[1for eachpoint in group1], 'or')  
  68. pl.plot([eachpoint[0for eachpoint in group2], [eachpoint[1for eachpoint in group2], 'oy')  
  69. pl.plot([eachpoint[0for eachpoint in group3], [eachpoint[1for eachpoint in group3], 'og')  
  70.   
  71. # 打印噪音点,黑色  
  72. pl.plot([eachpoint[0for eachpoint in noisePoint], [eachpoint[1for eachpoint in noisePoint], 'ok')     
  73.   
  74. pl.show()  





你可能感兴趣的:(聚类算法,机器学习&&数据挖掘)