K-Means算法具体内容可以参考我博客的相关文章,这里只使用Matlab对其进行实现,其他内容不多赘述
K-Means算法
1.生成随机样本点
首先利用 mvnrnd 函数生成3组满足高斯分布的数据,每组数据都是100*2的矩阵,也就相当于生成300个在坐标轴上的样本点
%% 第一组数据
mu1=[0 0]; %均值
S1=[0.1 0 ; 0 0.1]; %协方差
data1=mvnrnd(mu1,S1,100); %产生高斯分布数据
%% 第二组数据
mu2=[-1.25 1.25];
S2=[0.1 0 ; 0 0.1];
data2=mvnrnd(mu2,S2,100);
%% 第三组数据
mu3=[1.25 1.25];
S3=[0.1 0 ; 0 0.1];
data3=mvnrnd(mu3,S3,100);
mu1 、 mu2 、 mu3 是数据的均值,也就是你将每组点画在坐标轴上,其大致的中心位置坐标,例如对于上面的三组数据,中心点就分别为(0,0),(-1.25,1.25),(1.25,1.25),画在图上效果如下图
作图代码如下:
%% 显示数据
plot(data1(:,1),data1(:,2),'b+');
hold on;
plot(data2(:,1),data2(:,2),'b+');
plot(data3(:,1),data3(:,2),'b+');
2.初始化各矩阵
首先我们要将三个100*2的矩阵合并为一个300*2的矩阵 data = [data1;data2;data2]
然后初始化聚类中心,生成N行2列的零矩阵,这里的N是用户输入的想要聚为几类
还有就是要把data矩阵拷贝一份,尽量在算法执行过程中执行拷贝矩阵,而不去动data
%% 初始化变量
%% 初始化工作
data = [data1;data2;data3];
[m,n] = size(data); % m = 300,n = 2
center = zeros(N,n);% 初始化聚类中心,生成N行n列的零矩阵
pattern = data; % 将整个数据拷贝到pattern矩阵中
3.算法核心
一开始随机选取300个点中的N个点作为聚类中心(N是用户输入的聚类个数)。300个点分别计算到这N个中心点那一个最短,就将该点分为第几号。举个例子:
设有一个点的坐标是(0,0),分别有3个中心点(2,2),(1,1),(3,3),经过计算,(0,0)到(1,1)的距离是最短的,因此将(0,0)这个点划分为第2类
300个点全部划分完以后,假设用户输入的N是3,划分成60,90,150,然后计算60个点的中心点坐标(只要将60个点的x坐标加起来除以60,然后将y坐标加起来除以60,就能得到中心点),70个点的中心坐标,150个点的中心坐标,设这三个中心坐标为$(x_a,y_a)$,$(x_b,y_b)$,$(x_c,y_c)$,计算这三个中心点与之前随机选的三个中心点的距离是否小于一个阈值,如果都小于,则说明分类成功;只要有一个不满足,首先将这些新的中心坐标替换原来的中心坐标,然后重新分类
for x = 1 : N
center(x,:) = data(randi(300,1),:); % 第一次随机产生聚类中心 randi返回1*1的(1,300)的数
end
while true
distence = zeros(1,N); % 产生1行N列的零矩阵
num = zeros(1,N); % 产生1行N列的零矩阵
new_center = zeros(N,n); % 产生N行n列的零矩阵
%% 将所有的点打上标签1 2 3...N
for x = 1 : m
for y = 1 : N
distence(y) = norm(data(x,:) - center(y,:)); % norm函数计算到每个类的距离
end
[~,temp] = min(distence); %求最小的距离 ~是距离值,temp是第几个
pattern(x,n + 1) = temp;
end
k = 0;
%% 将所有在同一类里的点坐标全部相加,计算新的中心坐标
for y = 1 : N
for x = 1 : m
if pattern(x,n + 1) == y
new_center(y,:) = new_center(y,:) + pattern(x,1:n);
num(y) = num(y) + 1;
end
end
new_center(y,:) = new_center(y,:) / num(y);
if norm(new_center(y,:) - center(y,:)) < 0.1
k = k + 1;
end
end
if k == N
break;
else
center = new_center;
end
end
[m, n] = size(pattern); %[m,n] = [300,3]
4.绘制聚类后的数据点图
figure;
hold on;
for i = 1 : m
if pattern(i,n) == 1
plot(pattern(i,1),pattern(i,2),'r*');
plot(center(1,1),center(1,2),'ko');
elseif pattern(i,n) == 2
plot(pattern(i,1),pattern(i,2),'g*');
plot(center(2,1),center(2,2),'ko');
elseif pattern(i,n) == 3
plot(pattern(i,1),pattern(i,2),'b*');
plot(center(3,1),center(3,2),'ko');
elseif pattern(i,n) == 4
plot(pattern(i,1),pattern(i,2),'y*');
plot(center(4,1),center(4,2),'ko');
else
plot(pattern(i,1),pattern(i,2),'m*');
plot(center(5,1),center(5,2),'ko');
end
end
完整代码
clear;
clc;
N = input('请设置聚类数目:');%设置聚类数目
%% 第一组数据
mu1=[0 0]; %均值
S1=[0.1 0 ; 0 0.1]; %协方差
data1=mvnrnd(mu1,S1,100); %产生高斯分布数据
%% 第二组数据
mu2=[-1.25 1.25];
S2=[0.1 0 ; 0 0.1];
data2=mvnrnd(mu2,S2,100);
%% 第三组数据
mu3=[1.25 1.25];
S3=[0.1 0 ; 0 0.1];
data3=mvnrnd(mu3,S3,100);
%% 显示数据
plot(data1(:,1),data1(:,2),'b+');
hold on;
plot(data2(:,1),data2(:,2),'b+');
plot(data3(:,1),data3(:,2),'b+');
%% 初始化工作
data = [data1;data2;data3];
[m,n] = size(data); % m = 300,n = 2
center = zeros(N,n);% 初始化聚类中心,生成N行n列的零矩阵
pattern = data; % 将整个数据拷贝到pattern矩阵中
%% 算法
for x = 1 : N
center(x,:) = data(randi(300,1),:); % 第一次随机产生聚类中心 randi返回1*1的(1,300)的数
end
while true
distence = zeros(1,N); % 产生1行N列的零矩阵
num = zeros(1,N); % 产生1行N列的零矩阵
new_center = zeros(N,n); % 产生N行n列的零矩阵
%% 将所有的点打上标签1 2 3...N
for x = 1 : m
for y = 1 : N
distence(y) = norm(data(x,:) - center(y,:)); % norm函数计算到每个类的距离
end
[~,temp] = min(distence); %求最小的距离 ~是距离值,temp是第几个
pattern(x,n + 1) = temp;
end
k = 0;
%% 将所有在同一类里的点坐标全部相加,计算新的中心坐标
for y = 1 : N
for x = 1 : m
if pattern(x,n + 1) == y
new_center(y,:) = new_center(y,:) + pattern(x,1:n);
num(y) = num(y) + 1;
end
end
new_center(y,:) = new_center(y,:) / num(y);
if norm(new_center(y,:) - center(y,:)) < 0.1
k = k + 1;
end
end
if k == N
break;
else
center = new_center;
end
end
[m, n] = size(pattern); %[m,n] = [300,3]
%% 最后显示聚类后的数据
figure;
hold on;
for i = 1 : m
if pattern(i,n) == 1
plot(pattern(i,1),pattern(i,2),'r*');
plot(center(1,1),center(1,2),'ko');
elseif pattern(i,n) == 2
plot(pattern(i,1),pattern(i,2),'g*');
plot(center(2,1),center(2,2),'ko');
elseif pattern(i,n) == 3
plot(pattern(i,1),pattern(i,2),'b*');
plot(center(3,1),center(3,2),'ko');
elseif pattern(i,n) == 4
plot(pattern(i,1),pattern(i,2),'y*');
plot(center(4,1),center(4,2),'ko');
else
plot(pattern(i,1),pattern(i,2),'m*');
plot(center(5,1),center(5,2),'ko');
end
end
执行的GIF图如下:
存在的问题以及改进方法
这只是一个比较简单的K-Means聚类代码,其中可能存在两个问题:
死循环
聚类不准确
第一个问题产生的原因很简单,如果用笔算过K-Means就会知道,对于一个数据集,可能的聚类方式不止一种,并且存在确实无法达到所有的聚类中心差都小于阈值的情况。解决办法是加一个变量 times 用于记录执行了多少次while循环,当times达到一个很大的值而依旧没有停止程序,可以判断出现了死循环,干脆直接输出结果,不再计算。
第二个问题产生的效果图如下
对于右边的样本集,我们用肉眼观察很明显聚类应该如红框所示,但是使用K-Means聚类后得到的结果与预期差异较大,究其原因有很多种,具体解决办法是将阈值减小,以达到更加精确的聚类