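The theory is covered in the companion theory post. As shown there, the K-means algorithm has three main parts: random initialization, cluster assignment, and centroid movement.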
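Use the randperm function to randomly shuffle the sequence $1\sim m$ (where $m$ is the number of examples), then take the first $K$ entries as the indices of the initial centroids.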
function centroids = kMeansInitCentroids(X, K)
%KMEANSINITCENTROIDS This function initializes K centroids that are to be
%used in K-Means on the dataset X
% centroids = KMEANSINITCENTROIDS(X, K) returns K initial centroids to be
% used with the K-Means on the dataset X
%
% You should return these values correctly
centroids = zeros(K, size(X, 2));
% ====================== YOUR CODE HERE ======================
% Instructions: You should set centroids to randomly chosen examples from
% the dataset X
%
% Initialize the centroids to be random examples
% Randomly reorder the indices of examples
randidx = randperm(size(X,1));
% Take the first K examples
centroids = X(randidx(1:K),:);
% =============================================================
end
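For readers following along outside MATLAB, here is a rough Python sketch of the same idea: shuffle the row indices and keep the first K rows. It is an illustrative port, not part of the original exercise; indices are 0-based, and the name kmeans_init_centroids and the optional rng parameter are my own.

```python
import random

def kmeans_init_centroids(X, K, rng=random):
    """Pick K distinct rows of X as initial centroids (Python analogue of randperm)."""
    if K > len(X):
        raise ValueError("K cannot exceed the number of examples")
    # Shuffle the row indices 0..m-1, like randperm(m) but 0-based
    randidx = list(range(len(X)))
    rng.shuffle(randidx)
    # Take the rows at the first K shuffled indices
    return [list(X[i]) for i in randidx[:K]]
```

Passing a seeded random.Random instance as rng makes the choice reproducible, which is handy when comparing runs.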
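First assume every point belongs to cluster 1; then, for each point, try clusters 2 through K and update its label whenever that centroid is closer. The squared distance can be computed with a matrix product: because both the sample points and the centroid coordinates are row vectors here, (x-centroids(k,:))*(x-centroids(k,:))' is a scalar equal to the squared Euclidean distance.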
function idx = findClosestCentroids(X, centroids)
%FINDCLOSESTCENTROIDS computes the centroid memberships for every example
% idx = FINDCLOSESTCENTROIDS (X, centroids) returns the closest centroids
% in idx for a dataset X where each row is a single example. idx = m x 1
% vector of centroid assignments (i.e. each entry in range [1..K])
%
% Set K
K = size(centroids, 1);
% You need to return the following variables correctly.
idx = ones(size(X,1), 1);
% ====================== YOUR CODE HERE ======================
% Instructions: Go over every example, find its closest centroid, and store
% the index inside idx at the appropriate location.
% Concretely, idx(i) should contain the index of the centroid
% closest to example i. Hence, it should be a value in the
% range 1..K
%
% Note: You can use a for-loop over the examples to compute this.
%
for i=1:size(X,1)
    x=X(i,:);
    for k=2:K
        if ((x-centroids(k,:))*(x-centroids(k,:))')<((x-centroids(idx(i),:))*(x-centroids(idx(i),:))')
            idx(i)=k;
        end
    end
end
% =============================================================
end
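A Python version of the same assignment rule makes the tie-breaking explicit: starting from the first cluster and switching only on a strictly smaller distance means ties go to the lower index, just like the MATLAB loop. This is an illustrative port with 0-based indices; sq_dist and find_closest_centroids are hypothetical names.

```python
def sq_dist(a, b):
    # Same value as (x - c) * (x - c)' for row vectors: squared Euclidean distance
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def find_closest_centroids(X, centroids):
    """Return, for each row of X, the 0-based index of its nearest centroid."""
    idx = []
    for x in X:
        # Start by assuming the first cluster (cluster 1 in MATLAB)
        best, best_d = 0, sq_dist(x, centroids[0])
        for k in range(1, len(centroids)):
            d = sq_dist(x, centroids[k])
            if d < best_d:  # strict < keeps the lower index on ties
                best, best_d = k, d
        idx.append(best)
    return idx
```

For example, with centroids at (0, 0) and (10, 10), the point (5, 5) is equidistant from both and is assigned to index 0.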
Move each centroid to the mean of the points assigned to its cluster: add up the coordinates of those points and divide by their count.
function centroids = computeCentroids(X, idx, K)
%COMPUTECENTROIDS returns the new centroids by computing the means of the
%data points assigned to each centroid.
% centroids = COMPUTECENTROIDS(X, idx, K) returns the new centroids by
% computing the means of the data points assigned to each centroid. It is
% given a dataset X where each row is a single data point, a vector
% idx of centroid assignments (i.e. each entry in range [1..K]) for each
% example, and K, the number of centroids. You should return a matrix
% centroids, where each row of centroids is the mean of the data points
% assigned to it.
%
% Useful variables
[m n] = size(X);
% You need to return the following variables correctly.
centroids = zeros(K, n);
% ====================== YOUR CODE HERE ======================
% Instructions: Go over every centroid and compute mean of all points that
% belong to it. Concretely, the row vector centroids(i, :)
% should contain the mean of the data points assigned to
% centroid i.
%
% Note: You can use a for-loop over the centroids to compute this.
%
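cnt=zeros(K,1);
for i=1:m
    % Accumulate each point's coordinates into its cluster's row
    centroids(idx(i),:)=centroids(idx(i),:)+X(i,:);
    cnt(idx(i))=cnt(idx(i))+1;
end
% Divide each row by its point count (relies on implicit expansion,
% available in Octave and MATLAB R2016b+); an empty cluster gives 0/0 = NaN
centroids=centroids./cnt;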
% =============================================================
end
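The same accumulate-and-divide update can be sketched in Python (an illustrative port with 0-based cluster indices; compute_centroids is a hypothetical name). It returns NaN for an empty cluster to mirror what centroids./cnt produces in MATLAB when a count is zero.

```python
import math

def compute_centroids(X, idx, K):
    """Return K centroids, each the mean of the rows of X assigned to it (idx is 0-based)."""
    n = len(X[0])
    sums = [[0.0] * n for _ in range(K)]
    cnt = [0] * K
    # Accumulate coordinates and counts per cluster, as in the MATLAB loop
    for x, k in zip(X, idx):
        for j in range(n):
            sums[k][j] += x[j]
        cnt[k] += 1
    # Divide by the counts; an empty cluster becomes NaN, mirroring MATLAB's 0/0
    return [[s / cnt[k] if cnt[k] else math.nan for s in sums[k]] for k in range(K)]
```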
To show the solving process more clearly, random initialization is not used here. The function runkMeans.m is provided by Andrew Ng:
% Load an example dataset
load('ex7data2.mat');
% Settings for running K-Means
max_iters = 10;
% For consistency, here we set centroids to specific values, but in practice
% you want to generate them automatically, such as by setting them to be
% random examples (as can be seen in kMeansInitCentroids).
initial_centroids = [3 3; 6 2; 8 5];
% K and the starting assignments are needed by the first plotProgresskMeans call
K = size(initial_centroids, 1);
idx = findClosestCentroids(X, initial_centroids);
% Run K-Means algorithm. The 'true' at the end tells our function to plot the progress of K-Means
figure('visible','on'); hold on;
plotProgresskMeans(X, initial_centroids, initial_centroids, idx, K, 1);
xlabel('Press ENTER in command window to advance','FontWeight','bold','FontSize',14)
[~, ~] = runkMeans(X, initial_centroids, max_iters, true);
set(gcf,'visible','off'); hold off;
function [centroids, idx] = runkMeans(X, initial_centroids, ...
max_iters, plot_progress)
%RUNKMEANS runs the K-Means algorithm on data matrix X, where each row of X
%is a single example
% [centroids, idx] = RUNKMEANS(X, initial_centroids, max_iters, ...
% plot_progress) runs the K-Means algorithm on data matrix X, where each
% row of X is a single example. It uses initial_centroids as the
% initial centroids. max_iters specifies the total number of iterations
% of K-Means to execute. plot_progress is a true/false flag that
% indicates if the function should also plot its progress as the
% learning happens. This is set to false by default. runkMeans returns
% centroids, a Kxn matrix of the computed centroids and idx, a m x 1
% vector of centroid assignments (i.e. each entry in range [1..K])
%
% Set default value for plot progress
if ~exist('plot_progress', 'var') || isempty(plot_progress)
    plot_progress = false;
end
% Plot the data if we are plotting progress
if plot_progress
    figure;
    hold on;
end
% Initialize values
[m n] = size(X);
K = size(initial_centroids, 1);
centroids = initial_centroids;
previous_centroids = centroids;
idx = zeros(m, 1);
% Run K-Means
for i=1:max_iters
    % Output progress
    fprintf('K-Means iteration %d/%d...\n', i, max_iters);
    if exist('OCTAVE_VERSION')
        fflush(stdout);
    end
    % For each example in X, assign it to the closest centroid
    idx = findClosestCentroids(X, centroids);
    % Optionally, plot progress here
    if plot_progress
        plotProgresskMeans(X, centroids, previous_centroids, idx, K, i);
        previous_centroids = centroids;
        fprintf('Press enter to continue.\n');
        pause;
    end
    % Given the memberships, compute new centroids
    centroids = computeCentroids(X, idx, K);
end
% Hold off if we are plotting progress
if plot_progress
    hold off;
end
end
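Running this in MATLAB, you can watch the centroids move step by step while the colors of the other points (their cluster assignments) are updated.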
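Stripped of the plotting, runkMeans just alternates the assignment and update steps for a fixed number of iterations. Below is a compact, self-contained Python sketch of that loop. It is my own port, not Ng's code: run_kmeans is a hypothetical name, indices are 0-based, and keeping an empty cluster's previous centroid is a choice of this sketch, whereas the MATLAB computeCentroids above would produce NaN.

```python
def run_kmeans(X, initial_centroids, max_iters=10):
    """Alternate assignment and update steps, as runkMeans does (no plotting)."""
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

    centroids = [list(c) for c in initial_centroids]
    K = len(centroids)
    idx = [0] * len(X)
    for _ in range(max_iters):
        # Assignment step: nearest centroid for every example (lowest index wins ties)
        idx = [min(range(K), key=lambda k: sq_dist(x, centroids[k])) for x in X]
        # Update step: mean of the points in each cluster
        new_centroids = []
        for k in range(K):
            members = [x for x, c in zip(X, idx) if c == k]
            if members:
                new_centroids.append([sum(col) / len(members) for col in zip(*members)])
            else:
                # Sketch's choice: leave an empty cluster's centroid where it was
                new_centroids.append(centroids[k])
        centroids = new_centroids
    return centroids, idx
```

Like runkMeans, it always runs max_iters iterations and returns the assignments computed before the final centroid update; a common variation is to stop early once the assignments stop changing, which this sketch does not do.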
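Consider a complete image of $N\times M$ pixels, each stored in RGB. Every pixel then needs three unsigned integers in the range $0\sim 255$, i.e. $3\times 8=24$ bits, so the whole image takes $N\times M\times 24$ bits. If we compress the image's colors by finding 16 clusters and replacing each pixel's color with the mean color of its cluster, each pixel only has to store a color index in $1\sim 16$ (or $0\sim 15$), which takes $4$ bits. Adding the storage for the $16$ cluster colors themselves, the total is $16\times 24+N\times M\times 4$ bits, so for any reasonably large image the size shrinks to roughly $\frac{1}{6}$ of the original.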
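To make the storage estimate concrete, the small Python helper below plugs numbers into the formula. The name storage_bits is hypothetical, and the 128 × 128 image size used in the usage example is only an assumed size for illustration.

```python
def storage_bits(height, width, k_colors=16, bits_per_channel=8, channels=3):
    """Return (original_bits, compressed_bits) for a palette-compressed RGB image."""
    bits_per_color = bits_per_channel * channels          # 24 bits for one RGB color
    original = height * width * bits_per_color
    index_bits = (k_colors - 1).bit_length()              # bits to store an index 0..K-1
    compressed = k_colors * bits_per_color + height * width * index_bits
    return original, compressed
```

For a 128 × 128 image, storage_bits(128, 128) gives 393,216 bits originally and 16 × 24 + 16,384 × 4 = 65,920 bits after compression, a factor of about 6. The palette term matters only for very small images.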
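First, read the image with the imread function; the image is then stored in an $N\times M\times 3$ three-dimensional array. Next, use reshape to turn it into an $(N\times M)\times 3$ two-dimensional matrix with one row per pixel, which is the familiar form. Run K-means on it to find the clusters, then replace the color of every pixel in a cluster with the color of its centroid: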
% Load an image of a bird
A = double(imread('bird_small.png'));
A = A / 255; % Divide by 255 so that all values are in the range 0 - 1
% Size of the image
img_size = size(A);
X = reshape(A, img_size(1) * img_size(2), 3);
K = 16;
max_iters = 10;
initial_centroids = kMeansInitCentroids(X, K);
% Run K-Means
[centroids, ~] = runkMeans(X, initial_centroids, max_iters);
% Find closest cluster members
idx = findClosestCentroids(X, centroids);
X_recovered = centroids(idx,:);
% Reshape the recovered image into proper dimensions
X_recovered = reshape(X_recovered, img_size(1), img_size(2), 3);
% Display the original image
figure;
subplot(1, 2, 1);
imagesc(A);
title('Original');
axis square
% Display compressed image side by side
subplot(1, 2, 2);
imagesc(X_recovered)
title(sprintf('Compressed, with %d colors.', K));
axis square
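The reshape, assign, and recover steps can be mirrored in a small pure-Python sketch. Assumptions: the image is a nested list of H rows × W pixels × 3 channel values, the centroids are already known (for example from a K-means run like the one above), and the helper names flatten_pixels, nearest and compress_image are hypothetical. One design note: MATLAB's reshape works in column-major order, while this sketch flattens row by row; either order is fine as long as the recovery step undoes exactly the same order.

```python
def flatten_pixels(img):
    """H x W x 3 nested list -> list of H*W [r, g, b] rows, in row-major order."""
    return [list(px) for row in img for px in row]

def nearest(x, centroids):
    """0-based index of the centroid closest to x (squared Euclidean distance)."""
    return min(range(len(centroids)),
               key=lambda k: sum((a - b) ** 2 for a, b in zip(x, centroids[k])))

def compress_image(img, centroids):
    """Replace every pixel with its nearest centroid color and fold back to H x W x 3."""
    height, width = len(img), len(img[0])
    X = flatten_pixels(img)                      # like reshape(A, N*M, 3)
    idx = [nearest(x, centroids) for x in X]     # like findClosestCentroids
    flat = [list(centroids[k]) for k in idx]     # like centroids(idx,:)
    # Undo the flattening with the same row-major order used above
    recovered = [flat[r * width:(r + 1) * width] for r in range(height)]
    return recovered, idx
```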