K-Means算法是一种无监督的学习,根据事先给定的分类数K,将所有对象划分为K个簇,且簇内的中心采用簇内所有对象的均值计算而成。
引用Peter Harrington著,李锐等人翻译的《机器学习实战》一书中伪代码流程如下:
创建K个点作为初始质心(通常是随机选择)
当任意一个点的簇分配结果发生改变时
对数据集中的每个数据点
对每个质心
计算质心与数据点之间的距离
将数据点分配到距其最近的簇
对每一个簇,计算簇中所有点的均值并将均值作为质心
之后在阅读了《KNN与K-Means的区别》这篇博文后,其中这段话启发了我,并完成了K-Means算法的Matlab代码。
“(a)刚开始时是原始数据,杂乱无章,没有label,看起来都一样,都是绿色的。(b)假设数据集可以分为两类,令K=2,随机在坐标上选两个点,作为两个类的中心点。(c-f)演示了聚类的两种迭代。先划分,把每个数据样本划分到最近的中心点那一簇;划分完后,更新每个簇的中心,即把该簇的所有数据点的坐标加起来去平均值。这样不断进行”划分—更新—划分—更新”,直到每个簇的中心不在移动为止。”
在代码中有详细的注释,一为解释,二为提供思路。
主代码如下:
%思路来源:https://www.cnblogs.com/nucdy/p/6349172.html
%{
数据样本用圆点表示,每个簇的中心点用叉叉表示。
(a)刚开始时是原始数据,杂乱无章,没有label,看起来都一样,都是绿色的。
(b)假设数据集可以分为两类,令K=2,随机在坐标上选两个点,作为两个类的中心点。
(c-f)演示了聚类的两种迭代。先划分,把每个数据样本划分到最近的中心点那一簇;
划分完后,更新每个簇的中心,即把该簇的所有数据点的坐标加起来去平均值。
这样不断进行”划分—更新—划分—更新”,直到每个簇的中心不在移动为止。
%}
clc,close,clear all;
k=2;%k是分类数,为了方便其他文件采用CC_general方便,本脚本引入了变量k
k_rept=10;%循环次数tt上限,应该设大一些,方便中途达到标准
%第一类数据
mu1=[0 0]; %均值
S1=[0.3 0;0 0.35]; %协方差
data1=mvnrnd(mu1,S1,200); %产生高斯分布数据
%第二类数据
mu2=[1.25 1.25];
S2=[0.3 0;0 0.35];
data2=mvnrnd(mu2,S2,200);
%显示数据
plot(data1(:,1),data1(:,2),'+','MarkerSize',5);
hold on;
plot(data2(:,1),data2(:,2),'r+','MarkerSize',5);
grid on;
data=[data1;data2];
%随机产生2个点作为中心点
CC=zeros(k,2);
[CC,H]=CC_general(CC,k);
tt=1;
for tt=1:k_rept
%计算每个点距离中心点之间的距离,然后划分
data_n=size(data,1);%获取data的行数
dis=zeros(data_n,3);%每行第一个存储到第一个中心的距离,0默认是离中心点1近
dis(:,1)=sqrt((data(:,1)-CC(1,1)).^2+(data(:,2)-CC(1,2)).^2);
dis(:,2)=sqrt((data(:,1)-CC(2,1)).^2+(data(:,2)-CC(2,2)).^2);
for i=1:data_n
if dis(i,1)>dis(i,2)
dis(i,3)=1;
else if dis(i,1)==dis(i,2)
if i/2==0
dis(i,3)=1;
end
end
end
end
clear i
%检查中心点是否有效,可能会产生两个中心的偏于某一角,离所有点都很远的情况
%检查条件(data_n,1/3data_n)可以控制初始点的质量
dis_ck(tt,1)=sum(dis(:,3),1);
if (dis_ck(tt,1)==data_n)||(dis_ck(tt,1)==0)
fprintf(1,'本次中心点无效\n')
CC=zeros(k,2);
[CC,H]=CC_general(CC,k);
continue
end
n1=0;%存储data1的点数
n2=0;
for i=1:data_n
lab=dis(i,3);
if lab==1
n2=n2+1;
data2(n2,:)=data(i,:);
else if lab==0
n1=n1+1;
data1(n1,:)=data(i,:);
end
end
end
clear i
data1=data1(1:n1,:);%把已覆盖的留下,原先的删去
data2=data2(1:n2,:);
data=[data1;data2];
new(1,:)=mean(data1,1);
new(2,:)=mean(data2,1);
if isequal(new,CC)
fprintf(1,'已完成迭代,第%d次迭代确定了最终中心点为:\n',tt)
fprintf(1,'中心点1为:%f, %f\n',CC(1,1), CC(1,2))
fprintf(1,'中心点2为:%f, %f\n',CC(2,1), CC(2,2))
break
end
set(H,'Visible','off');%删去H中的全部,即删去图中原先的CC
CC(1,:)=new(1,:);
CC(2,:)=new(2,:);
H(1,1)=plot(CC(1,1),CC(1,2),'g.','MarkerSize',15);% ,'MarkerSize','10'
H(1,2)=text(CC(1,1)+0.2,CC(1,2),'1');
hold on
H(2,1)=plot(CC(2,1),CC(2,2),'g.','MarkerSize',15);% ,'MarkerSize','10'
H(2,2)=text(CC(2,1)+0.2,CC(2,2),'2');
hold on
end
生成中心点的代码 CC_general.m
function [CC,H] = CC_general(CC,k)
%CC_GENERAL 随机生成k个初始的中心点
% k是分类数,也是初始中心点数
CC(:,1)=-1+4*rand(1,2);%产生随机横坐标(-1~3),CC第一列填充随机数
CC(:,2)=-1+4*rand(1,2);%产生随机纵坐标
for i=1:k
H(i,1)=plot(CC(i,1),CC(i,2),'o','MarkerSize',10);% ,'MarkerSize','10'
CC_order=num2str(i);
H(i,2)=text(CC(i,1)+0.2,CC(i,2),CC_order);
hold on
end
end