机器学习(4)--层次聚类(hierarchical clustering)基本原理及实现简单图片分类

关于层次聚类(hierarchical clustering)的基本步骤:
1、假设每个样本为一类,计算每个类的距离,也就是相似度
2、把最近的两个合为一新类,这样类别数量就少了一个
3、重新新类与各个旧类(去了那两个合并的类)之间的相似度;
4、循环重复2和3直到所有样本点都归为一类

这个计算的过程,相当于重构一个二叉树,只是这个过程,是从树叶-->树枝-->树干的构建过程

本例将以14张图片,做为样本,进行聚类,点击这里  下载图片样本

以下是使用我提供的图片库生成的分类结果,以及一张PS修后对代码中各变量的说明

机器学习(4)--层次聚类(hierarchical clustering)基本原理及实现简单图片分类_第1张图片

机器学习(4)--层次聚类(hierarchical clustering)基本原理及实现简单图片分类_第2张图片

当然,你也可以自己定义一个目录,程序会读取目录下所有JPG图片

如果你用了自己的图片,在代码中的一此数据的变化说明,可能和使用产生的数据不同了,

同时,本文的主要目的是层次聚类(hierarchical clustering)的基本步骤,对于图片的相似度的算法并不完善,效果也并不是十分理想,不过如果你使用自己从手机中导入的生活照,不同的场景大致还是能分类出来的

# -*- coding:utf-8 -*-

from PIL import ImageDraw,Image
import numpy as np
import os
import sys


nodeList = []#用于存储所有的节点,包含图片节点,与聚类后的节点
distance = {}#用于存储所有每两个节点的距离,数据格式{(node1.id,node2.id):30.0,(node2.id,node3.id):40.0}
class node:
    def __init__(self, data):
        '''每个样本及样本合并后节点的类
            data:接受两种格式,
            1、当为字符(string)时,是图片的地址,同时也表示这个节点就是图片
            2、合并后的类,传入的格式为(leftNode,rightNode) 即当前类表示合并后的新类,而对应的左右节点就是子节点
        '''
        self.id = len(nodeList)#设置一个ID,以nodeList当然长度为ID,在本例中ID本身没太大用处,只是如果看代码时,有时要看指向时有点用
        self.parent = None # 指向合并后的类
        self.pos = None#用于最后绘制节构图使用,赋值时为(x,y,w,h)格式
        if type(data) == type("") :
            '''节点为图片'''
            self.imgData = Image.open(data)
            self.left = None
            self.right = None 
            self.level = 0    #图片为最终的子节点,所有图片的层级都为0,设置层级是为了最终绘制结构图

            npTmp = np.array(self.imgData).reshape(-1,3) #将图片数据转化为numpy数据,shape为(高,宽,3),3为颜色通道
            npTmp = npTmp.reshape(-1,3)  #重新排列,shape为(高*宽,3)
            self.feature = npTmp.mean(axis=0)#计算RGB三个颜色通道均值

        else:
            '''节点为合成的新类'''
            self.imgData = None
            self.left = data[0]
            self.right = data[1]
            self.left.parent = self
            self.right.parent = self

            self.level = max(self.left.level,self.right.level) + 1 #层级为左右节高层级的级数+1
            self.feature = (self.left.feature + self.right.feature) / 2 #两类的合成一类时,就是左右节点的feature相加/2
            
        #计算该类与每个其他类的距离,并存入distance
        for x in nodeList:
            distance[(x,self)] = np.sqrt(np.sum((x.feature - self.feature) ** 2))

        nodeList.append(self)#将本类加入nodeList变量

    def drawNode(self,img,draw,vLineLenght):
        #绘制结构图
        if self.pos == None:return
        if self.left == None:
            #如果是图片
            self.imgData.thumbnail((self.pos[2], self.pos[3]))
            img.paste(self.imgData,(self.pos[0], self.pos[1]))
            draw.line((int(self.pos[0] + self.pos[2] / 2)
                 , self.pos[1] - vLineLenght
                 , int(self.pos[0] + self.pos[2] / 2)
                 , self.pos[1])
                , fill=(255, 0, 0))
        else:
            #如果不是图片
            draw.line((int(self.pos[0])
                 , self.pos[1]
                 , int(self.pos[0] + self.pos[2])
                 , self.pos[1])
                , fill=(255, 0, 0))

            draw.line((int(self.pos[0] + self.pos[2] / 2)
                    , self.pos[1]
                    , int(self.pos[0] + self.pos[2] / 2)
                    , self.pos[1] - self.pos[3])
                    , fill=(255, 0, 0))

def loadImg(path):
    '''path 图片目录,根据自己存的地方改写'''
    files = None
    try :
        files = os.listdir(path)
    except:
        print('未正确读取目录:' + path + ',图片目录,请根据自己存的地方改写,并保证没有hierarchicalResult.jpg,该文件为最后生成文件')
        return None
    for i in files:

        if os.path.splitext(i)[1].lower() == '.jpg' and os.path.splitext(i)[0].lower() != 'hierarchicalresult':

            fileName = os.path.join(path,i)
            node(fileName)
    return os.path.join(path,'hierarchicalResult.jpg')

def getMinDistance():
    '''从distance中过滤出未分类的结点,并读取最小的距离'''
    vars = list(filter(lambda x:x[0].parent == None and x[1].parent == None ,distance))
    minDist = vars[0]
    for x in vars:
        if minDist == None or distance[x] < distance[minDist]:
            minDist = x
    return minDist

def createTree():
    while len(list(filter(lambda x:x.parent == None ,nodeList))) > 1:#合并到最后时,只有一个类,只要有两个以上未合并,就循环
        minDist = getMinDistance()
        #创建非图片的节点,之所以把[1]做为左节点,因为绘图时的需要,
        #在不断的产生非图片节点时,在nodeList的后面的一般是新节点,但绘图时绘在左边
        node((minDist[1],minDist[0])) 
    return nodeList[-1]#最后一个插入的节点就是要节点


def run():
    root = createTree()#创建树结构

    #一句话的PYTON,实现二叉树的左右根遍历,通过通过遍历,进行排序后,取出图片,做为最底层的打印
    sortTree = lambda node:([] if node.left == None else sortTree(node.left)) + ([] if node.right == None else sortTree(node.right)) + [node]
    treeTmp = sortTree(root)
    treeTmp = list(filter(lambda x:x.left == None,treeTmp))#没有左节点的,即为图片

    thumbSize = 60 #缩略图的大小,,在60X60的小格内缩放
    thumbSpace = 20 #缩略图间距
    vLineLenght = 80 #上下节点,即每个level之间的高度

    imgWidth = len(treeTmp) * (thumbSize + thumbSpace)
    imgHeight = (root.level+1) * vLineLenght + thumbSize + thumbSpace*2
    img = Image.new('RGB', (imgWidth,imgHeight), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    for item in enumerate(treeTmp):
        #为所有图片增加绘图数据
        x = item[0] * (thumbSize + thumbSpace) + thumbSpace / 2
        y = imgHeight - thumbSize - thumbSpace / 2 - ((item[1].parent.level - 1) * vLineLenght)
        w = item[1].imgData.width
        h = item[1].imgData.height
        if w > h:
            h = h / w * thumbSize
            w = thumbSize
        else:
            w = w / h * thumbSize
            h = thumbSize
            x+=(thumbSize - w) / 2
        item[1].pos = (int(x),int(y),int(w),int(h))
        item[1].drawNode(img,draw,vLineLenght)

    for x in range(1,root.level + 1):
        #为所有非图片增加绘图的数据
        items = list(filter(lambda i:i.level == x,nodeList))
        for item in items:
            x = item.left.pos[0] + (item.left.pos[2] / 2)
            w = item.right.pos[0] + (item.right.pos[2] / 2) - x
            y = item.left.pos[1] - (item.level - item.left.level) * vLineLenght
            h = ((item.parent.level if item.parent != None else item.level + 1) - item.level) * vLineLenght
            item.pos = (int(x),int(y),int(w),int(h))
            item.drawNode(img,draw,vLineLenght)
    img.save(resultFile)

resultFile = loadImg(r"E:\hierarchicalImgs")#读取数据,并返回最后结果要存储的文件名,目录根据自己存的位置进行修改
if resultFile != 'None':
    run()
    print("结构图生成成功,最终结构图存储于:" + resultFile)



你可能感兴趣的:(python,机器学习)