前言
前面的学习中,已经详细了解了K均值算法的相关原理。本篇文章,我们将使用python实现K均值算法并将其应用于图像压缩处理。这对K均值算法的直观理解是非常有帮助的。在本次算法实现中,利用K均值算法减少一副图像中图像颜色数量出现最多的部分像素,来实现图像压缩。
算法实现
K均值算法的主要原理就是给定一组初始数据,并初始化聚类中心,根据初始化的聚类中心,将的数据分配给最近的聚类中心,并重新计算新的聚类中心,一直重复这个过程,直到没有最新的数据分配给聚类中心或者是聚类中心不再发生新改变。
寻找聚类中心
K均值算法中,需要将每一个训练样本分配给最接近该样本的聚类中心,对于每一个训练样本,可以用如下公式求取其聚类中心:
其中,表示最接近的聚类中心索引,而则表示第个聚类中心的值或者位置,在代码中,用idx[i]
表示。
寻找聚类中心的的算法可以用以下代码实现:
- 初始化
首先,对一些参数进行初始化,如加载训练样本,聚类中心的数目和初始值设置,如下代码所示:
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as scio
from skimage import io
from skimage import img_as_float
#加载训练样本
data = scio.loadmat('ex7data2.mat')
X = data['X']
# 选择初始的聚类中心的数量和位置(value)
k = 3 #聚类中心的数目
initial_centroids = np.array([[3, 3], [6, 2], [8, 5]]) #聚类中心的初始值
- 寻找聚类中心
求取聚类中心的过程可以内外两层循环完成,内循环表示求取每一个聚类中心与样本的范数(距离),而外循环表示每一个样本减去聚类中心的所得的值。通过以上步骤,可以得到一个300X3的范数矩阵,最后,返回行方向上最小值索引(长度为300的一维数组)。
def find_closest_centroids(X, centroids):
K = centroids.shape[0] # K=3
m = X.shape[0] # m = 300
idx = np.zeros(m)
means = np.zeros((m, K))
for i in range(m):
x = X[i]
#外循环,每一个x的位置(二维矩阵)减去聚类中心的位置(二维矩阵)
diff = x - centroids
for k in range(K):
#内循环,x减去每一个聚类中心所得范数
means[i, k] = np.linalg.norm(diff[k])
#聚类中心的行方向上最小值所对应的索引
idx = np.argmin(means, axis=1)
return idx
注意:通过求取聚类中心,得到了由聚类中心索引所构成的一维数组。求取聚类中心之后,相当于给每个训练样本打上聚类中心的标识,而这一维数组的索引与训练样本的索引相对应,例如,300个训练样本,最后求得一个长度为300的一维数组,而其值表示训练样本所对应的聚类中心的标识。
计算聚类中心均值
以上,我们求取了行方向上最小聚类中心的索引,对于给定的个聚类中心,需要求得给第个聚类中心的所有训练样本的均值,可以用如下公式表示:
假设,有两个训练样本被分配给了聚类中心,则
计算聚类中心的均值的算法实现,如下代码所示
def compute_centroids(X, idx, K):
(m, n) = X.shape #m=300,n=2
centroids = np.zeros((K, n))
for k in range(K):
#每个聚类中心索引所对应的样本x,表示分配给聚类中心索引k的训练样本x
x_for_centroid_k = X[np.where(idx == k)]
#分配给索引为k的聚类中心的样本x在列方向上的和除以分配给聚类中心所对应的样本数量
centroid_k = np.sum(x_for_centroid_k, axis=0) / x_for_centroid_k.shape[0]
centroids[k] = centroid_k
return centroids
K均值算法的可视化实现
通过以上步骤,已经得到了训练样本的聚类中心和其均值,通过以下代码,通过十次迭代,可视化的实现K均值算法在训练样本中的运行方式,具体实现,如下代码所示:
def run_kmeans(X, initial_centroids, max_iters, plot):
if plot:
plt.figure()
(m, n) = X.shape
K = initial_centroids.shape[0]
centroids = initial_centroids
previous_centroids = centroids
idx = np.zeros(m)
for i in range(max_iters):
print('K-Means iteration {}/{}'.format((i + 1), max_iters))
idx = find_closest_centroids(X, centroids)
if plot:
plot_progress(X, centroids, previous_centroids, idx, K, i)
previous_centroids = centroids
input('Press ENTER to continue')
centroids = compute_centroids(X, idx, K)
return centroids, idx
def plot_progress(X, centroids, previous, idx, K, i):
plt.scatter(X[:, 0], X[:, 1], c=idx, s=15)
plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', c='black', s=25)
for j in range(centroids.shape[0]):
draw_line(centroids[j], previous[j])
plt.title('Iteration number {}'.format(i + 1))
def draw_line(p1, p2):
plt.plot(np.array([p1[0], p2[0]]), np.array([p1[1], p2[1]]), c='black', linewidth=1)
通过10次迭代后,运行图像,如下图所示:
利用K均值算法实现图像压缩
以上,已经详细了解并实现了K均值算法,在这部分内容中,将使用K均值算法来实现图像压缩,所谓图像压缩指的是在图像像素方面的处理。图像常用的编码方式为RGB编码,即用三基色(RED,GREEN,BLUE)表示图像颜色。每个像素由三个8位无符号二进制数(范围从0到255)表示其像素颜色,例如,一个像素的颜色可以用(220,101,25)
表示。给定的图像包含这数千种颜色,通过K均值算法,可以将其颜色的数量降至16种,从而实现图像压缩。
像素处理
图像的每个像素颜色即代表训练样本,通过K均值算法寻找16种颜色代表图像中的所有像素的颜色,即也就是寻找16个聚类中心,最后,将所有的像素颜色替换为16个聚类中心所对应的颜色。
- 图像加载和预处理
对于每一个像素,可以用一个三维矩阵表示,其中,第一维和第二维表示其所在位置,第三维代表其是蓝色,红色,或者绿色。例如一个矩阵,表示53行,44列所在的像素其颜色为3.
在此过程中,需要将图像转换为的矩阵,其中,m=像素的行×列。其实现过程可以用如下代码表示
image = io.imread('bird_small.png')
#将图像转换为浮点型数据
image = img_as_float(image)
img_shape = image.shape
X = image.reshape(img_shape[0] * img_shape[1], 3)
通过以上代码,将图像转化为(128×128,3)的二维矩阵
运行K均值算法处理图像
- 聚类中心的随机初始化
在运行算法之前,还需要对聚类中心,进行随机初始化的处理,其初始化过程也就是对位置进行初始化,具体实现方式如下代码所示
def kmeans_init_centroids(X, K):
centroids = np.zeros((K, X.shape[1]))
indices = np.random.randint(X.shape[0], size=K)
centroids = X[indices]
return centroids
根据之前的算法实现,现在,可以直接运行K均值算法了,其实现方式如下所示
K = 16 #设置聚类中心数量
max_iters = 10 #最大迭代次数
initial_centroids = kmeans_init_centroids(X, K)
centroids, idx = run_kmeans(X, initial_centroids, max_iters, False)
- 实现图像压缩
经过以上处理,图像压缩的实现步骤如下代码所示:
idx = find_closest_centroids(X, centroids)
X_recovered = centroids[idx]
# (128*128*3)
X_recovered = np.reshape(X_recovered, (img_shape[0], img_shape[1], 3))
plt.subplot(2, 1, 1)
plt.imshow(image)
plt.title('Original')
plt.subplot(2, 1, 2)
plt.imshow(X_recovered)
plt.title('Compressed, with {} colors'.format(K))
最后,经过处理过后的图像对比如下图所示,可以明显的看出,第二幅图像的质量有所降低。