高光谱isodata算法,可实现一条一条像素点的读取

独自完成的isodata算法!!!
欧耶!我真厉害
代码主体是借鉴的这个链接:https://blog.csdn.net/zsiming/article/details/122410398,分裂合并逻辑也没有做更改,但是结合我的自身需要改了很多部分:主要改的是数据读取部分,我要将高高光谱数据一条一条或者一帧一帧的读取,不能一次性全部load进去数据。
全部代码见如下:

#https://blog.csdn.net/zsiming/article/details/122410398
import numpy as np
import seaborn as sns
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.metrics import euclidean_distances
from read import *
from loadonelinedata import *
import numpy
from  PyQt5.QtGui import QColor,QImage,QPixmap


class ISODATA():#定义类
    def __init__( self,designCenterNum, LeastSampleNum, StdThred, LeastCenterDist, iterationNum):#定义完类后必须用__init__进行初始化,且必须是进入类后第一步执行的,执行完成后返回主函数
        #  指定预期的聚类数、每类的最小样本数、标准差阈值、最小中心距离、迭代次数
        self.K = designCenterNum
        self.thetaN = LeastSampleNum
        self.thetaS = StdThred
        self.thetaC = LeastCenterDist
        self.iteration = iterationNum
        self.DataName = "Camera1refClip.raw"
        self.Hdr = LoadHdr(self.DataName)
        self.sampels = self.Hdr['samples']
        self.lines = self.Hdr['lines']
        self.bands = self.Hdr['bands']
        self.t = 0.000031      #需要用户指定,范围为0-1之间

        firstPoint = LoadPointWaves(self.DataName, self.Hdr, 0, 0)  # 是 bands*1 的矩阵

        self.center = firstPoint.reshape(1,-1)
        self.centerNum = 1
        self.centerMeanDist = 0

    def updateLabel(self):
        """
            更新中心
        """
        self.classes = np.ones((self.sampels, self.lines)) * (-1)
        for j in range(self.lines):
            self.data = LoadOneLineData(self.DataName, self.Hdr, j)  # 一次性加载一个samples * bands的矩阵
            distance = np.zeros((self.centerNum,self.sampels))
            for i in range(self.centerNum):  # i为类别数
                distance[i, :] = np.linalg.norm(self.data - self.center[i,:],axis=1)  # linalg.norm代表差值平方后再加和再开根号。产生i*sample的矩阵(每更新一次j,就要重新覆盖一次distance)。此处减法为广播机制
            self.classes[:, j] = np.argmin(distance,axis=0)  # 给出最小值的下标;axis = 0代表对纵轴进行查找。产生sample*lines的矩阵。储存的是满足条件的点的行数(行列数均从0开始计算)
        # 找出相同类的样本
        # np.unique(self.classes)
        self.classes = self.classes.astype(int)#将其转为整形
        self.centerNum=len(np.unique(self.classes))
        self.point = np.unique(self.classes)#找到不同的元素
        self.centerNumList = self.point.astype(int).tolist()#将不同的元素转为list格式,为j的循环做准备
        # newpoints = [[] for i in range(self.centerNum)]  # 产生classNum个[]
        newpoints = [[] for i in range(max(self.centerNumList) + 1)]
        for i in range(self.bands):
            self.data = LoadOneWaveData(self.DataName, self.Hdr, i)  # 加载某一个波段的灰度图
            for j in self.centerNumList:
                newpoints[j].append(np.mean(self.data[np.where(self.classes == j)]))  # ????where找到=j的点,返回对应值。将所有被分为j类的点取平均值,该平均值为其新的中心点
        # 将newpoints转为数组,原list中的空值转为0值存入新数组中pointsSet
        pointsSet = np.zeros((max(self.centerNumList) + 1,self.bands))
        for i, j in enumerate(newpoints):
            pointsSet[i][0:len(j)] = j
        # pointsSet=pointsSet[pointsSet!= 0]
        # pointsSet = pointsSet.reshape(self.centerNum, -1)
        self.center = pointsSet # 更新中心(包含0元素数组)

        # 计算所有类到各自中心的平均距离之和
        # points = [[] for i in range(self.centerNum)]  # 产生classNum个[]
        for k in range(max(self.centerNumList) + 1):
        # for k in range(self.centerNum):
            points = [[] for i in range(max(self.centerNumList) + 1)]  # 产生classNum个[]
            for i in range(self.bands):
                self.data = LoadOneWaveData(self.DataName, self.Hdr, i)  # 加载某一条line的全部光谱(samples * bands的矩阵)
                for j in self.centerNumList:
                    points[j].append(self.data[np.where(self.classes == j)])# 找出data中所有同一类样本!!!!!
            # 计算样本到中心的距离
            ppoints = numpy.array(points[k])
            center = np.zeros((self.bands,1))
            center= numpy.array(self.center[k]).reshape(-1,1)
            distance = np.mean(np.linalg.norm(ppoints-center))
            # 更新中心
            self.centerMeanDist += distance
        self.centerMeanDist /= self.centerNum

    def divide(self):
        # lines = self.Hdr['lines']
        # bands = self.Hdr['bands']
        # 临时保存更新后的中心集合,否则在删除和添加的过程中顺序会乱
        # newCenterSet = self.center.reshape(self.centerNum,-1)
        newCenterSet = self.center
        for j in range(max(self.centerNumList) + 1):
            points = [[] for i in range(max(self.centerNumList) + 1)]  # 产生classNum个[]
            indexs = np.where(self.classes == j)  # [(),()]
            stdEachDim = 0
            for i in range(len(indexs[0])):
                y = indexs[0][i]
                x = indexs[1][i]
                points[j]=(LoadPointWaves(self.DataName, self.Hdr, x, y))
                # 计算每个类的样本在每个维度的标准差
                ppoints = numpy.array(points[j])    #将list转为array进行运算
                center = numpy.array(self.center).reshape(max(self.centerNumList) + 1,-1)           #将list转为array进行运算
                center = center[j,:].reshape(1,-1)
                # 计算样本到中心每个维度的标准差
                stdEachDim = stdEachDim+np.mean((ppoints-center)**2, axis=0)  #axis = 0代表对纵轴进行查找
            # 找出的最大标准差的纵坐标
            a=1
            # while stdEachDim.any()== 0 or a==2:
            while np.any(stdEachDim == 0) or a == 2:
                break
            else:
                maxIndex = np.argmax(stdEachDim)     #返回标准差最大值的列坐标
                maxStd = stdEachDim[maxIndex]       #找到stdEachDim中最大值坐标的具体值
                # 计算样本到本类中心的距离
                distance = np.mean(np.linalg.norm(ppoints-center))
                # 如果最大标准差超过了阈值
                a=a+1
                if maxStd > self.thetaS:
                    # 还需要该类的样本数大于阈值很多 且 太分散才进行分裂
                    if self.centerNum <= self.K//2 or \
                            ppoints.shape[0] > 2 * (self.thetaN+1) and distance >= self.centerMeanDist:
                        # center = np.zeros((self.bands))
                        # center = numpy.array(self.center[j])
                        newCenterFirst = center.copy()
                        newCenterSecond = center.copy()

                        newCenterFirst += self.t * maxStd
                        newCenterSecond -= self.t * maxStd

                        # 删除原始中心
                        newCenterSet = np.delete(newCenterSet, j, axis=0)
                        # 添加新中心
                        newCenterSet = np.vstack((newCenterSet, newCenterFirst))
                        newCenterSet = np.vstack((newCenterSet, newCenterSecond))
                else:
                    break
        # 更新中心集合
        self.center = newCenterSet
        self.centerNum = len(self.center)#shape[0]输出行数

    def combine(self):
        # 临时保存更新后的中心集合,否则在删除和添加的过程中顺序会乱
        delIndexList = []

        # 计算中心之间的距离
        centerDist = euclidean_distances(self.center, self.center)
        self.centershape=self.center.shape[0]
        centerDist += (np.eye(self.centershape)) * 10**10  #eye生成对角线为1的对称矩阵
        # 把中心距离小于阈值的中心对找出来
        while True:
            # 如果最小的中心距离都大于阈值的话,则不再进行合并
            minDist = np.min(centerDist)
            if minDist >= self.thetaC:
                break
            # 否则合并(此部分是找到中心点间距离最小的一对中心点。该对是label值为row值和col值的两个中心)
            index = np.argmin(centerDist)            #将数据展成一行,找到最小值的列数(列数从0开始数)
            row = index // self.centershape            #符号//表示取除法结果的整数部分
            col = index % self.centershape            #符号%表示取余数
            # 找出合并的两个类别
            index = np.argwhere(self.classes == row).squeeze()
            classNumFirst = len(index)
            index = np.argwhere(self.classes == col).squeeze()
            classNumSecond = len(index)
            a=1
            while classNumFirst == 0 or classNumSecond == 0:
                delIndexList.append(row)
                delIndexList.append(col)
                self.centerNum -= 1
                self.center = np.delete(self.center, [row,col], axis=0)
                centerDist[row, :] = float("inf")
                centerDist[col, :] = float("inf")
                centerDist[:, col] = float("inf")
                centerDist[:, row] = float("inf")
                break
            else:
                a=a+1
                newCenter = self.center[row, :] * (classNumFirst / (classNumFirst+ classNumSecond)) + \
                            self.center[col, :] * (classNumSecond / (classNumFirst+ classNumSecond))
                # 记录被合并的中心
                delIndexList.append(row)
                delIndexList.append(col)
                # 增加合并后的中心
                self.center = np.vstack((self.center, newCenter))
                self.centerNum -= 1
                # 标记,以防下次选中
                centerDist[row, :] = float("inf")
                centerDist[col, :] = float("inf")
                centerDist[:, col] = float("inf")
                centerDist[:, row] = float("inf")

        # 更新中心
        self.center = np.delete(self.center, delIndexList, axis=0)
        self.centerNum = self.center.shape[0]   #shape[0]为矩阵的行数

    def showResult(self):
        for j in self.centerNumList:
            indexs = np.where(self.classes == j)
            indexs= numpy.array(indexs)
            x_value = indexs[0,:]
            y_value = indexs[1,:]
            plt.scatter(x_value, y_value, label="A产品")
        plt.show()


    def train(self):
        # 初始化中心和label
        self.updateLabel()
        # self.drawResult()
        self.showResult()

        # 到设定的次数自动退出
        for i in range(self.iteration):
            # 如果是偶数次迭代或者中心的数量太多,那么进行合并
            if self.centerNum < self.K //2:
                self.divide()
            elif (i > 0 and i % 2 == 0) or self.centerNum > 2 * self.K:
                self.combine()
            else:
                self.divide()
            # 更新中心
            self.updateLabel()
            # self.drawResult()
            self.showResult()
            print("中心数量:{}".format(self.centerNum))




if __name__ == "__main__":
    isoData = ISODATA(designCenterNum=5, LeastSampleNum=10, StdThred=0.01, LeastCenterDist=2, iterationNum=20)#调用ISODATA类(用户指定预期的聚类数、每类的最小样本数、标准差阈值、最小中心距离、迭代次数)
    isoData.train()#运行完第一步后进行该步骤

这个算法需要科研工作者对自身的数据有很多的了解,因为需要用户指定:每类的最小样本数、标准差阈值、最小中心距离、一个系数t。
系数t决定了每次分裂后新产生的center是否合理,如果不合理会导致永远无法找到正确的新center,会进行合并、分裂、合并…的死循环,而分类结果没有任何提升
每类的最小样本数用户应该心里大致有数
标准差阈值用户也应该心里大致有数
最小中心距离这个很麻烦,但是如果参数设置偏差太多也会很麻烦。

总之,这让我对算法是什么产生了新的认识,这项工作很有意义!

你可能感兴趣的:(算法,机器学习,python)