基于KMeans算法的图像分割例子

文章目录

  • 一、理论基础
    • 1、KMeans算法
    • 2、图像分割
  • 二、实验过程
    • 1、图片
    • 2、实验步骤
    • 3、Python代码
      • (1)导包
      • (2)读取图像数据
      • (3)处理图像数据
      • (4)KMeans聚类
      • (5)可视化展示
      • (6)完整代码
  • 三、参考文献

一、理论基础

1、KMeans算法

请参考这里。

2、图像分割

图像分割(Image segmentation)技术是计算机视觉领域的一个重要的研究方向,是图像语义理解的重要一环。图像分割是指将图像分成若干具有相似性质的区域的过程,从数学角度来看,图像分割是将图像划分成互不相交的区域的过程。

二、实验过程

1、图片

本文以图1的彩色图(命名为“girl.jpg”)为例进行图像分割。
基于KMeans算法的图像分割例子_第1张图片

图1 girl.jpg

2、实验步骤

  1. 读取图像数据,并对数据进行处理
  2. 对图像数据矩阵进行KMeans聚类
  3. 输出图像,观察结果

3、Python代码

(1)导包

from sklearn.cluster import KMeans
from matplotlib.image import imread
import matplotlib.pyplot as plt

(2)读取图像数据

image = imread('girl.jpg')
image.shape

结果显示为:

(1200, 1200, 3)

(3)处理图像数据

对读取后的图像矩阵image进行处理,使之满足KMeans聚类所需的数据要求,即样本数×特征数。

X = image.reshape(-1,3)
X.shape

结果显示为:

(1440000, 3)

(4)KMeans聚类

为了比较簇个数对聚类结果的影响,本文分别设置了簇个数为10、8、6、4、2,比较最后分割后的结果。

segmented_imgs = []
n_colors = (10,8,6,4,2)
for n_cluster in n_colors:
    kmeans = KMeans(n_clusters=n_cluster,random_state=42).fit(X)
    segmented_img = kmeans.cluster_centers_[kmeans.labels_]
    segmented_imgs.append(segmented_img.reshape(image.shape))

(5)可视化展示

plt.figure(figsize=(12,8))
plt.subplot(231)
plt.imshow(image.astype('uint8'))
plt.title('Original image')

for idx,n_clusters in enumerate(n_colors):
    plt.subplot(232+idx)
    plt.imshow(segmented_imgs[idx].astype('uint8'))
    plt.title('{} colors'.format(n_clusters))

结果显示:
基于KMeans算法的图像分割例子_第2张图片

(6)完整代码

# 作者:心升明月
# 开发时间:2022/2/10 18:34
from sklearn.cluster import KMeans
from matplotlib.image import imread
import matplotlib.pyplot as plt

# 读取图像数据
image = imread('../data/girl.jpg')
# 处理图像数据
X = image.reshape(-1,3)

# KMeans聚类
segmented_imgs = []
n_colors = (10,8,6,4,2)
for n_cluster in n_colors:
    kmeans = KMeans(n_clusters=n_cluster,random_state=42).fit(X)
    segmented_img = kmeans.cluster_centers_[kmeans.labels_]
    segmented_imgs.append(segmented_img.reshape(image.shape))

# 可视化展示
plt.figure(1,figsize=(12,8))
plt.subplot(231)
plt.imshow(image.astype('uint8'))
plt.title('Original image')
for idx,n_clusters in enumerate(n_colors):
    plt.subplot(232+idx)
    plt.imshow(segmented_imgs[idx].astype('uint8'))
    plt.title('{} colors'.format(n_clusters))
# plt.savefig('result.png')
plt.show()

三、参考文献

[1] 唐宇迪. 跟着迪哥学Python数据分析与机器学习实战[M]. 北京: 人民邮电出版社, 2019: 346-352.
[2] 宁萌Julie. Python中”Clipping input data to the valid range for imshow with RGB data ”的问题解决[Z]. CSDN博客, 2021.

你可能感兴趣的:(机器学习,聚类,kmeans,图像分割)