The kmeans clustering algorithm and its MATLAB implementation

1. Introduction to the kmeans clustering algorithm

  kmeans is a classic unsupervised machine-learning algorithm and one of the top ten algorithms in data mining. It is a very practical clustering algorithm, and both its idea and its implementation are simple. The main idea of kmeans: partition the data into regions (clusters) so that the sum of squared distances between the data and the region centres is minimised. Put differently, kmeans quantises the data onto the cluster centres, and its objective is to minimise the "information" lost during this quantisation. Solving the kmeans objective can also be viewed as EM (Expectation-Maximisation) iterative optimisation.

2. The kmeans objective function and its optimisation

Definitions: data set $X_{N\times D}$, cluster centres $U_{K\times D}$, indicator matrix $r_{N\times K}$,
where $r_{nk}=1$ if the $n$-th sample belongs to the $k$-th cluster and $r_{nk}=0$ otherwise. The goal of kmeans clustering is to minimise the sum of squared distances between the data and their cluster centres:

$$P=\sum_{n=1}^{N}\sum_{k=1}^{K} r_{nk}\left\|x_n-u_k\right\|^2$$

  Iterative optimisation: the unknowns are two matrices, the cluster centres $U_{K\times D}$ and the indicator matrix $r_{N\times K}$. Optimising both jointly is too hard (the problem is NP-hard), so the two variables are optimised alternately to reach a local optimum.
1. Fix the cluster centres $U_{K\times D}$ and optimise the indicator matrix $r_{N\times K}$: if the $n$-th sample is closest to the $k$-th cluster, set $r_{nk}=1$; otherwise set $r_{nk}=0$.
2. Fix the indicator matrix $r_{N\times K}$ and optimise the cluster centres $U_{K\times D}$: the objective splits across the centres (there is no coupling between them), so each centre can be optimised separately. The derivation:
$$P=\sum_{k=1}^{K}\sum_{n=1}^{N} r_{nk}\left\|x_n-u_k\right\|^2 \;\Longrightarrow\; \min_{u_k}\sum_{n=1}^{N} r_{nk}\left\|x_n-u_k\right\|^2,\quad k=1,\dots,K$$

Setting the derivative with respect to $u_k$ to zero:

$$-2\sum_{n=1}^{N} r_{nk}\left(x_n-u_k\right)=0 \;\Longrightarrow\; u_k=\frac{\sum_{n=1}^{N} r_{nk}\,x_n}{\sum_{n=1}^{N} r_{nk}},\quad \forall k$$

  From the EM viewpoint: treat the indicator matrix $r_{N\times K}$ as the latent variable and the cluster centres $U_{K\times D}$ as the model parameters; updating $r_{N\times K}$ and $U_{K\times D}$ then corresponds to the E-step and the M-step of the EM algorithm, respectively.
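These two alternating steps can be sketched directly in MATLAB. This is a minimal illustration on toy data, not the full implementation given later; variable names follow the notation above ($X$, $U$, indicator matrix $R$), and implicit expansion requires MATLAB R2016b or later (use bsxfun on older versions):

```matlab
% One kmeans iteration on toy data, matching the notation above.
rng(0);
X = randn(500, 2);                    % N-by-D data matrix
K = 3;
U = X(randperm(size(X, 1), K), :);    % K-by-D centres, initialised from the data

% Step 1 (E-step): fix U, update the indicator matrix R.
% Squared distances via ||x||^2 + ||u||^2 - 2*x'u (N-by-K, implicit expansion).
D2 = sum(X.^2, 2) + sum(U.^2, 2)' - 2 * X * U';
[~, idx] = min(D2, [], 2);            % nearest centre for each sample
id = eye(K);
R = id(idx, :);                       % one-of-K encoding: r_nk = 1 iff sample n is in cluster k

% Step 2 (M-step): fix R, update each centre as the mean of its assigned samples.
counts = sum(R, 1)';                  % K-by-1 number of points per cluster
U = (R' * X) ./ max(counts, 1);       % max(...,1) guards against empty clusters
```

In practice the two steps are repeated until one of the stopping criteria of the next section is met.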

3. Stopping criteria

The iterative optimisation of kmeans continues until some stopping criterion is met. Stop if, in the current iteration:

1. the cluster assignments of the training samples no longer change, or only a few samples change;
2. the objective value changes very little, or the centre vectors change very little;
3. the maximum number of iterations is reached.

Training stops as soon as any one of these conditions holds. If condition 1 or 2 holds, the algorithm has converged.
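A minimal self-contained sketch of these tests (the thresholds and data here are illustrative; the sp_kmeans implementation later in this article reads the corresponding thresholds from options(2), options(3), and options(14)):

```matlab
rng(0);
X = randn(300, 2);
K = 3;
U = X(randperm(size(X, 1), K), :);    % initial centres drawn from the data
max_iters = 100;                      % criterion 3: iteration cap
tol_U = 1e-4; tol_e = 1e-4;           % thresholds for criteria 1-2
e_old = Inf;
for it = 1:max_iters
    U_old = U;
    D2 = sum(X.^2, 2) + sum(U.^2, 2)' - 2 * X * U';   % needs R2016b+ (or bsxfun)
    [d, idx] = min(D2, [], 2);        % assignment and per-sample squared distance
    for k = 1:K
        if any(idx == k)
            U(k, :) = mean(X(idx == k, :), 1);
        end
    end
    e = sum(d);                       % objective value for this iteration
    % Stop when both the centres and the objective have (almost) stopped changing.
    if max(abs(U(:) - U_old(:))) < tol_U && abs(e_old - e) < tol_e
        break;
    end
    e_old = e;
end
fprintf('stopped after %d iterations, objective %.4f\n', it, e);
```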

4. Choosing the value of K

  As the number of clusters K grows, the objective value decreases; on the other hand, a larger K increases storage and computation. How, then, should a suitable K be chosen?
1. Rule of thumb: set the number of clusters by hand, based on the nature of the problem and prior knowledge.
2. Elbow method: once the number of clusters passes a certain value, further increases change the objective only slightly; this inflection point (the "elbow") can be taken as the optimal number of clusters.

[Figure 1: objective value versus the number of clusters K]
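The elbow heuristic can be sketched with the sp_kmeans routine given later in this article (this assumes sp_kmeans and sp_dist2 are on the MATLAB path; the data here are synthetic):

```matlab
% Run kmeans for a range of K and record the final objective value.
rng(0);
data = randn(1000, 2);
Ks = 1:10;
errs = zeros(size(Ks));
for i = 1:numel(Ks)
    opts = zeros(1, 14);
    opts(1) = -1;                 % run silently
    opts(2) = 1e-4;               % precision required of the centres
    opts(3) = 1e-4;               % precision required of the error function
    opts(5) = 1;                  % initialise centres from the data
    opts(14) = 100;               % maximum number of iterations
    centres = zeros(Ks(i), 2);    % only the size matters here, since opts(5) = 1
    [~, opts] = sp_kmeans(centres, data, opts);
    errs(i) = opts(8);            % final sum of squared distances
end
plot(Ks, errs, '-o'); xlabel('K'); ylabel('objective P');
```

On data with genuine cluster structure, the curve drops steeply up to the true number of clusters and flattens afterwards; the bend is the candidate K.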

5. Computing the distances between data and cluster centres in MATLAB

  MATLAB was created to make matrix computation convenient, so using for loops to compute the distances between the data and the cluster centres (in order to build the indicator matrix $r_{N\times K}$) is far too slow and should be avoided. A small transformation of the squared Euclidean distance makes the computation fully vectorisable:

$$\left\|x_n-u_k\right\|^2=\left(x_n-u_k\right)^{\top}\left(x_n-u_k\right)=x_n^{\top}x_n+u_k^{\top}u_k-2x_n^{\top}u_k=\left\|x_n\right\|^2+\left\|u_k\right\|^2-2x_n^{\top}u_k$$

for $n=1,\dots,N$ and $k=1,\dots,K$.

MATLAB code (adapted from Ian Nabney's Netlab toolbox):

function n2 = sp_dist2(x, c)
% DIST2 Calculates squared distance between two sets of points.
% Adapted from Netlab neural network software:
% http://www.ncrg.aston.ac.uk/netlab/index.php
%
%   Description
%   D = DIST2(X, C) takes two matrices of vectors and calculates the
%   squared Euclidean distance between them.  Both matrices must be of
%   the same column dimension.  If X has M rows and N columns, and C has
%   L rows and N columns, then the result has M rows and L columns.  The
%   I, Jth entry is the  squared distance from the Ith row of X to the
%   Jth row of C.
%
%
%   Copyright (c) Ian T Nabney (1996-2001)

[ndata, dimx] = size(x);
[ncentres, dimc] = size(c);
if dimx ~= dimc
    error('Data dimension does not match dimension of centres')
end

n2 = (ones(ncentres, 1) * sum((x.^2)', 1))' + ...
  ones(ndata, 1) * sum((c.^2)',1) - ...
  2.*(x*(c'));

% Rounding errors occasionally cause negative entries in n2
if any(any(n2<0))
  n2(n2<0) = 0;
end
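A quick sanity check, on small random matrices, that the vectorised sp_dist2 agrees with a naive double loop over the distance definition:

```matlab
rng(1);
x = randn(5, 3);                  % 5 points in 3 dimensions
c = randn(4, 3);                  % 4 centres in 3 dimensions
n2 = sp_dist2(x, c);              % vectorised squared distances, 5-by-4

n2_loop = zeros(5, 4);            % same quantity, computed directly
for i = 1:5
    for j = 1:4
        n2_loop(i, j) = sum((x(i, :) - c(j, :)).^2);
    end
end

% The difference should be at the level of floating-point rounding error.
disp(max(abs(n2(:) - n2_loop(:))));
```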

6. A MATLAB implementation of kmeans

The sp_dist2 function above is used to complete the kmeans clustering algorithm:

function [centres, options, post, errlog] = sp_kmeans(centres, data, options)
% KMEANS    Trains a k means cluster model.
% Adapted from Netlab neural network software:
% http://www.ncrg.aston.ac.uk/netlab/index.php
%
%   Description
%    CENTRES = KMEANS(CENTRES, DATA, OPTIONS) uses the batch K-means
%   algorithm to set the centres of a cluster model. The matrix DATA
%   represents the data which is being clustered, with each row
%   corresponding to a vector. The sum of squares error function is used.
%   The point at which a local minimum is achieved is returned as
%   CENTRES.  The error value at that point is returned in OPTIONS(8).
%
%   [CENTRES, OPTIONS, POST, ERRLOG] = KMEANS(CENTRES, DATA, OPTIONS)
%   also returns the cluster number (in a one-of-N encoding) for each
%   data point in POST and a log of the error values after each cycle in
%   ERRLOG.    The optional parameters have the following
%   interpretations.
%
%   OPTIONS(1) is set to 1 to display error values; also logs error
%   values in the return argument ERRLOG. If OPTIONS(1) is set to 0, then
%   only warning messages are displayed.  If OPTIONS(1) is -1, then
%   nothing is displayed.
%
%   OPTIONS(2) is a measure of the absolute precision required for the
%   value of CENTRES at the solution.  If the absolute difference between
%   the values of CENTRES between two successive steps is less than
%   OPTIONS(2), then this condition is satisfied.
%
%   OPTIONS(3) is a measure of the precision required of the error
%   function at the solution.  If the absolute difference between the
%   error functions between two successive steps is less than OPTIONS(3),
%   then this condition is satisfied. Both this and the previous
%   condition must be satisfied for termination.
%
%   OPTIONS(14) is the maximum number of iterations; default 100.
%
%   Copyright (c) Ian T Nabney (1996-2001)

[ndata, data_dim] = size(data);
[ncentres, dim] = size(centres);

if dim ~= data_dim
  error('Data dimension does not match dimension of centres')
end

if (ncentres > ndata)
  error('More centres than data')
end

% Sort out the options
if (options(14))
  niters = options(14);
else
  niters = 100;
end

store = 0;
if (nargout > 3)
  store = 1;
  errlog = zeros(1, niters);
end

% Check if centres and posteriors need to be initialised from data
if (options(5) == 1)
  % Do the initialisation
  perm = randperm(ndata);
  perm = perm(1:ncentres);

  % Assign first ncentres (permuted) data points as centres
  centres = data(perm, :);
end
% Matrix to make unit vectors easy to construct
id = eye(ncentres);

% Main loop of algorithm
for n = 1:niters

  % Save old centres to check for termination
  old_centres = centres;

  % Calculate posteriors based on existing centres
  d2 = sp_dist2(data, centres);
  % Assign each point to nearest centre
  [minvals, index] = min(d2', [], 1);
  post = id(index,:);

  num_points = sum(post, 1);
  % Adjust the centres based on new posteriors
  for j = 1:ncentres
    if (num_points(j) > 0)
      centres(j,:) = sum(data(find(post(:,j)),:), 1)/num_points(j);
    end
  end

  % Error value is total squared distance from cluster centres
  e = sum(minvals);
  if store
    errlog(n) = e;
  end
  if options(1) > 0
    fprintf(1, 'Cycle %4d  Error %11.6f\n', n, e);
  end

  if n > 1
    % Test for termination
    if max(max(abs(centres - old_centres))) < options(2) & ...
        abs(old_e - e) < options(3)
      options(8) = e;
      return;
    end
  end
  old_e = e;
end

% If we get here, then we haven't terminated in the given number of 
% iterations.
options(8) = e;
if (options(1) >= 0)
  disp('Warning: Maximum number of iterations has been exceeded');
end

Usage example:

options = zeros(1,14);
options(1) = 1; % display
options(2) = 1; % precision required of the centres
options(3) = 0.1; % precision required of the error function
options(5) = 1; % initialization
options(14) = 100; % maximum iterations
data = random('Normal',0,1,10000,1000);
centres = zeros(100, 1000);
[centres, options, post, errlog] = sp_kmeans(centres, data, options);
Sample output:

Cycle    1  Error 17938493.427068
Cycle    2  Error 9830031.343386
Cycle    3  Error 9820962.377269
Cycle    4  Error 9817432.475919
Cycle    5  Error 9815699.935060
Cycle    6  Error 9814457.480926
Cycle    7  Error 9813788.116749
Cycle    8  Error 9813344.693653
Cycle    9  Error 9812973.717258
Cycle   10  Error 9812774.174316
Cycle   11  Error 9812574.730180
Cycle   12  Error 9812428.284545
Cycle   13  Error 9812320.224245
Cycle   14  Error 9812214.447799
Cycle   15  Error 9812136.019543
Cycle   16  Error 9812084.447704
Cycle   17  Error 9812069.131204
Cycle   18  Error 9812069.131204
