k-means is a classic unsupervised machine learning algorithm, listed among the top ten algorithms in data mining. As a highly practical clustering algorithm, k-means is simple in both concept and implementation. Its main idea: partition the data into regions (clusters) so that the total distance between each point and its region's center is minimized. Viewed another way, k-means quantizes the data to the cluster centers, and its objective function minimizes the "information" lost in this quantization. The process of optimizing the k-means objective can also be seen as EM (Expectation-Maximization) iterative optimization.
Definitions: data set $X_{N\times D}$, cluster centers $U_{K\times D}$, indicator matrix $r_{N\times K}$, where $r_{nk}=1$ if the $n$-th sample belongs to the $k$-th cluster and $r_{nk}=0$ otherwise. The goal of k-means clustering is to minimize the total squared distance between the data and their corresponding cluster centers:

$$J = \sum_{n=1}^{N}\sum_{k=1}^{K} r_{nk}\,\|x_n - u_k\|^2$$
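To make the EM view concrete, here are the standard alternating updates for minimizing $J$ (a textbook derivation, stated for completeness): the E-step fixes the centers and optimizes the assignments, while the M-step fixes the assignments and recomputes each center as the mean of its assigned points.

$$\text{E-step:}\quad r_{nk} = \begin{cases} 1, & k = \arg\min_j \|x_n - u_j\|^2 \\ 0, & \text{otherwise} \end{cases} \qquad \text{M-step:}\quad u_k = \frac{\sum_{n} r_{nk}\, x_n}{\sum_{n} r_{nk}}$$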
The iterative optimization of k-means continues until a stopping criterion is met. If, in the current iteration:
1. The cluster assignments of the training samples no longer change, or only a few samples change assignment;
2. The objective function changes very little, or the cluster center vectors change very little;
3. The maximum number of iterations is reached;
then training stops as soon as any one of these conditions holds. If condition 1 or 2 is satisfied, the algorithm has converged.
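A minimal MATLAB sketch of such a termination test (all variable names and tolerances here are illustrative; the full implementation below performs an equivalent check):

% Minimal sketch of the termination test; toy values stand in for
% the state a real k-means loop would maintain.
tol_centres = 1e-4;  tol_err = 1e-4;  max_iters = 100;
old_centres = [0 0; 1 1];  centres = [0 0; 1 1];   % toy centre matrices
old_e = 5.0;  e = 5.0;  n = 2;                     % toy error values and cycle count
converged = max(max(abs(centres - old_centres))) < tol_centres && ...
            abs(old_e - e) < tol_err;              % criterion 2
stop = converged || n >= max_iters                 % criterion 2 or 3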
As the number of clusters K increases, the objective function decreases. On the other hand, a larger K increases storage and computation. So how should a suitable K be chosen?
1. Empirical method: specify the number of clusters manually, based on the nature of the problem and prior knowledge.
2. Elbow (knee-point) method: once the number of clusters reaches a certain value, adding more clusters changes the objective only slightly; this knee point can be taken as the optimal number of clusters (see the sketch after this list).
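A minimal MATLAB sketch of this knee-point search, using the sp_kmeans function defined later in this post (the toy data set, the candidate range of K, and the precision settings are illustrative assumptions):

% Sketch: run k-means over a range of K and look for the knee
% in the final error curve.
data = randn(1000, 2);                  % toy data set
Ks = 2:10;                              % candidate cluster counts
errs = zeros(size(Ks));
for i = 1:numel(Ks)
    opts = zeros(1, 14);
    opts(1) = -1;                       % run silently
    opts(2) = 1e-4; opts(3) = 1e-4;     % precision thresholds
    opts(5) = 1;                        % initialise centres from data
    opts(14) = 100;                     % maximum iterations
    c0 = zeros(Ks(i), size(data, 2));   % placeholder; overwritten by init
    [~, opts] = sp_kmeans(c0, data, opts);
    errs(i) = opts(8);                  % final sum-of-squares error
end
plot(Ks, errs, '-o'); xlabel('K'); ylabel('objective J');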
MATLAB was created precisely to make matrix computation convenient, so it would be unacceptably slow to compute the distances between the data and the cluster centers with for loops in order to obtain the indicator matrix $r_{N\times K}$. Instead, apply a small transformation to the squared Euclidean distance:

$$\|x_n - c_k\|^2 = \|x_n\|^2 + \|c_k\|^2 - 2\,x_n^\top c_k,$$

which lets the whole $N\times K$ distance matrix be computed with a handful of matrix operations.
MATLAB code (borrowing from an expert's work, the Netlab toolbox):
function n2 = sp_dist2(x, c)
% DIST2 Calculates squared distance between two sets of points.
% Adapted from Netlab neural network software:
% http://www.ncrg.aston.ac.uk/netlab/index.php
%
% Description
% D = DIST2(X, C) takes two matrices of vectors and calculates the
% squared Euclidean distance between them. Both matrices must be of
% the same column dimension. If X has M rows and N columns, and C has
% L rows and N columns, then the result has M rows and L columns. The
% I, Jth entry is the squared distance from the Ith row of X to the
% Jth row of C.
%
%
% Copyright (c) Ian T Nabney (1996-2001)
[ndata, dimx] = size(x);
[ncentres, dimc] = size(c);
if dimx ~= dimc
    error('Data dimension does not match dimension of centres')
end
n2 = (ones(ncentres, 1) * sum((x.^2)', 1))' + ...
     ones(ndata, 1) * sum((c.^2)', 1) - ...
     2.*(x*(c'));
% Rounding errors occasionally cause negative entries in n2
if any(any(n2<0))
    n2(n2<0) = 0;
end
With the sp_dist2 function above in hand, we can implement the k-means clustering algorithm:
function [centres, options, post, errlog] = sp_kmeans(centres, data, options)
% KMEANS Trains a k means cluster model.
% Adapted from Netlab neural network software:
% http://www.ncrg.aston.ac.uk/netlab/index.php
%
% Description
% CENTRES = KMEANS(CENTRES, DATA, OPTIONS) uses the batch K-means
% algorithm to set the centres of a cluster model. The matrix DATA
% represents the data which is being clustered, with each row
% corresponding to a vector. The sum of squares error function is used.
% The point at which a local minimum is achieved is returned as
% CENTRES. The error value at that point is returned in OPTIONS(8).
%
% [CENTRES, OPTIONS, POST, ERRLOG] = KMEANS(CENTRES, DATA, OPTIONS)
% also returns the cluster number (in a one-of-N encoding) for each
% data point in POST and a log of the error values after each cycle in
% ERRLOG. The optional parameters have the following
% interpretations.
%
% OPTIONS(1) is set to 1 to display error values; also logs error
% values in the return argument ERRLOG. If OPTIONS(1) is set to 0, then
% only warning messages are displayed. If OPTIONS(1) is -1, then
% nothing is displayed.
%
% OPTIONS(2) is a measure of the absolute precision required for the
% value of CENTRES at the solution. If the absolute difference between
% the values of CENTRES between two successive steps is less than
% OPTIONS(2), then this condition is satisfied.
%
% OPTIONS(3) is a measure of the precision required of the error
% function at the solution. If the absolute difference between the
% error functions between two successive steps is less than OPTIONS(3),
% then this condition is satisfied. Both this and the previous
% condition must be satisfied for termination.
%
% OPTIONS(14) is the maximum number of iterations; default 100.
%
% Copyright (c) Ian T Nabney (1996-2001)
[ndata, data_dim] = size(data);
[ncentres, dim] = size(centres);
if dim ~= data_dim
    error('Data dimension does not match dimension of centres')
end
if (ncentres > ndata)
    error('More centres than data')
end
% Sort out the options
if (options(14))
    niters = options(14);
else
    niters = 100;
end
store = 0;
if (nargout > 3)
    store = 1;
    errlog = zeros(1, niters);
end
% Check if centres and posteriors need to be initialised from data
if (options(5) == 1)
    % Do the initialisation
    perm = randperm(ndata);
    perm = perm(1:ncentres);
    % Assign first ncentres (permuted) data points as centres
    centres = data(perm, :);
end
% Matrix to make unit vectors easy to construct
id = eye(ncentres);
% Main loop of algorithm
for n = 1:niters
    % Save old centres to check for termination
    old_centres = centres;
    % Calculate posteriors based on existing centres
    d2 = sp_dist2(data, centres);
    % Assign each point to nearest centre
    [minvals, index] = min(d2', [], 1);
    post = id(index,:);
    num_points = sum(post, 1);
    % Adjust the centres based on new posteriors
    for j = 1:ncentres
        if (num_points(j) > 0)
            centres(j,:) = sum(data(find(post(:,j)),:), 1)/num_points(j);
        end
    end
    % Error value is total squared distance from cluster centres
    e = sum(minvals);
    if store
        errlog(n) = e;
    end
    if options(1) > 0
        fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end
    if n > 1
        % Test for termination
        if max(max(abs(centres - old_centres))) < options(2) && ...
                abs(old_e - e) < options(3)
            options(8) = e;
            return;
        end
    end
    old_e = e;
end
% If we get here, then we haven't terminated in the given number of
% iterations.
options(8) = e;
if (options(1) >= 0)
    disp('Warning: Maximum number of iterations has been exceeded');
end
A quick test: cluster 10,000 random 1000-dimensional standard-normal samples into 100 clusters.
options = zeros(1,14);
options(1) = 1;      % display error values
options(2) = 1;      % precision required of the centres
options(3) = 0.1;    % precision required of the error
options(5) = 1;      % initialise centres from random data points
options(14) = 100;   % maximum number of iterations
data = random('Normal',0,1,10000,1000);
centres = zeros(100, 1000);
[centres, options, post, errlog] = sp_kmeans(centres, data, options);
Output:
Cycle 1 Error 17938493.427068
Cycle 2 Error 9830031.343386
Cycle 3 Error 9820962.377269
Cycle 4 Error 9817432.475919
Cycle 5 Error 9815699.935060
Cycle 6 Error 9814457.480926
Cycle 7 Error 9813788.116749
Cycle 8 Error 9813344.693653
Cycle 9 Error 9812973.717258
Cycle 10 Error 9812774.174316
Cycle 11 Error 9812574.730180
Cycle 12 Error 9812428.284545
Cycle 13 Error 9812320.224245
Cycle 14 Error 9812214.447799
Cycle 15 Error 9812136.019543
Cycle 16 Error 9812084.447704
Cycle 17 Error 9812069.131204
Cycle 18 Error 9812069.131204
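The error stops decreasing between cycles 17 and 18, so the termination test on options(2) and options(3) fires and the run converges after 18 cycles. The returned errlog makes this easy to visualise; a minimal sketch, assuming the test above has just been run:

% Sketch: plot the recorded per-cycle error (trailing zeros in errlog
% correspond to cycles that never executed).
nc = find(errlog, 1, 'last');        % last cycle actually executed
plot(1:nc, errlog(1:nc), '-o');
xlabel('cycle'); ylabel('sum-of-squares error J');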