独自完成的isodata算法!!!
欧耶!我真厉害
代码主体是借鉴的这个链接:https://blog.csdn.net/zsiming/article/details/122410398,分裂合并逻辑也没有做更改,但是结合我的自身需要改了很多部分:主要改的是数据读取部分,我要将高高光谱数据一条一条或者一帧一帧的读取,不能一次性全部load进去数据。
全部代码见如下:
#https://blog.csdn.net/zsiming/article/details/122410398
import numpy as np
import seaborn as sns
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.metrics import euclidean_distances
from read import *
from loadonelinedata import *
import numpy
from PyQt5.QtGui import QColor,QImage,QPixmap
class ISODATA():#定义类
def __init__( self,designCenterNum, LeastSampleNum, StdThred, LeastCenterDist, iterationNum):#定义完类后必须用__init__进行初始化,且必须是进入类后第一步执行的,执行完成后返回主函数
# 指定预期的聚类数、每类的最小样本数、标准差阈值、最小中心距离、迭代次数
self.K = designCenterNum
self.thetaN = LeastSampleNum
self.thetaS = StdThred
self.thetaC = LeastCenterDist
self.iteration = iterationNum
self.DataName = "Camera1refClip.raw"
self.Hdr = LoadHdr(self.DataName)
self.sampels = self.Hdr['samples']
self.lines = self.Hdr['lines']
self.bands = self.Hdr['bands']
self.t = 0.000031 #需要用户指定,范围为0-1之间
firstPoint = LoadPointWaves(self.DataName, self.Hdr, 0, 0) # 是 bands*1 的矩阵
self.center = firstPoint.reshape(1,-1)
self.centerNum = 1
self.centerMeanDist = 0
def updateLabel(self):
"""
更新中心
"""
self.classes = np.ones((self.sampels, self.lines)) * (-1)
for j in range(self.lines):
self.data = LoadOneLineData(self.DataName, self.Hdr, j) # 一次性加载一个samples * bands的矩阵
distance = np.zeros((self.centerNum,self.sampels))
for i in range(self.centerNum): # i为类别数
distance[i, :] = np.linalg.norm(self.data - self.center[i,:],axis=1) # linalg.norm代表差值平方后再加和再开根号。产生i*sample的矩阵(每更新一次j,就要重新覆盖一次distance)。此处减法为广播机制
self.classes[:, j] = np.argmin(distance,axis=0) # 给出最小值的下标;axis = 0代表对纵轴进行查找。产生sample*lines的矩阵。储存的是满足条件的点的行数(行列数均从0开始计算)
# 找出相同类的样本
# np.unique(self.classes)
self.classes = self.classes.astype(int)#将其转为整形
self.centerNum=len(np.unique(self.classes))
self.point = np.unique(self.classes)#找到不同的元素
self.centerNumList = self.point.astype(int).tolist()#将不同的元素转为list格式,为j的循环做准备
# newpoints = [[] for i in range(self.centerNum)] # 产生classNum个[]
newpoints = [[] for i in range(max(self.centerNumList) + 1)]
for i in range(self.bands):
self.data = LoadOneWaveData(self.DataName, self.Hdr, i) # 加载某一个波段的灰度图
for j in self.centerNumList:
newpoints[j].append(np.mean(self.data[np.where(self.classes == j)])) # ????where找到=j的点,返回对应值。将所有被分为j类的点取平均值,该平均值为其新的中心点
# 将newpoints转为数组,原list中的空值转为0值存入新数组中pointsSet
pointsSet = np.zeros((max(self.centerNumList) + 1,self.bands))
for i, j in enumerate(newpoints):
pointsSet[i][0:len(j)] = j
# pointsSet=pointsSet[pointsSet!= 0]
# pointsSet = pointsSet.reshape(self.centerNum, -1)
self.center = pointsSet # 更新中心(包含0元素数组)
# 计算所有类到各自中心的平均距离之和
# points = [[] for i in range(self.centerNum)] # 产生classNum个[]
for k in range(max(self.centerNumList) + 1):
# for k in range(self.centerNum):
points = [[] for i in range(max(self.centerNumList) + 1)] # 产生classNum个[]
for i in range(self.bands):
self.data = LoadOneWaveData(self.DataName, self.Hdr, i) # 加载某一条line的全部光谱(samples * bands的矩阵)
for j in self.centerNumList:
points[j].append(self.data[np.where(self.classes == j)])# 找出data中所有同一类样本!!!!!
# 计算样本到中心的距离
ppoints = numpy.array(points[k])
center = np.zeros((self.bands,1))
center= numpy.array(self.center[k]).reshape(-1,1)
distance = np.mean(np.linalg.norm(ppoints-center))
# 更新中心
self.centerMeanDist += distance
self.centerMeanDist /= self.centerNum
def divide(self):
# lines = self.Hdr['lines']
# bands = self.Hdr['bands']
# 临时保存更新后的中心集合,否则在删除和添加的过程中顺序会乱
# newCenterSet = self.center.reshape(self.centerNum,-1)
newCenterSet = self.center
for j in range(max(self.centerNumList) + 1):
points = [[] for i in range(max(self.centerNumList) + 1)] # 产生classNum个[]
indexs = np.where(self.classes == j) # [(),()]
stdEachDim = 0
for i in range(len(indexs[0])):
y = indexs[0][i]
x = indexs[1][i]
points[j]=(LoadPointWaves(self.DataName, self.Hdr, x, y))
# 计算每个类的样本在每个维度的标准差
ppoints = numpy.array(points[j]) #将list转为array进行运算
center = numpy.array(self.center).reshape(max(self.centerNumList) + 1,-1) #将list转为array进行运算
center = center[j,:].reshape(1,-1)
# 计算样本到中心每个维度的标准差
stdEachDim = stdEachDim+np.mean((ppoints-center)**2, axis=0) #axis = 0代表对纵轴进行查找
# 找出的最大标准差的纵坐标
a=1
# while stdEachDim.any()== 0 or a==2:
while np.any(stdEachDim == 0) or a == 2:
break
else:
maxIndex = np.argmax(stdEachDim) #返回标准差最大值的列坐标
maxStd = stdEachDim[maxIndex] #找到stdEachDim中最大值坐标的具体值
# 计算样本到本类中心的距离
distance = np.mean(np.linalg.norm(ppoints-center))
# 如果最大标准差超过了阈值
a=a+1
if maxStd > self.thetaS:
# 还需要该类的样本数大于阈值很多 且 太分散才进行分裂
if self.centerNum <= self.K//2 or \
ppoints.shape[0] > 2 * (self.thetaN+1) and distance >= self.centerMeanDist:
# center = np.zeros((self.bands))
# center = numpy.array(self.center[j])
newCenterFirst = center.copy()
newCenterSecond = center.copy()
newCenterFirst += self.t * maxStd
newCenterSecond -= self.t * maxStd
# 删除原始中心
newCenterSet = np.delete(newCenterSet, j, axis=0)
# 添加新中心
newCenterSet = np.vstack((newCenterSet, newCenterFirst))
newCenterSet = np.vstack((newCenterSet, newCenterSecond))
else:
break
# 更新中心集合
self.center = newCenterSet
self.centerNum = len(self.center)#shape[0]输出行数
def combine(self):
# 临时保存更新后的中心集合,否则在删除和添加的过程中顺序会乱
delIndexList = []
# 计算中心之间的距离
centerDist = euclidean_distances(self.center, self.center)
self.centershape=self.center.shape[0]
centerDist += (np.eye(self.centershape)) * 10**10 #eye生成对角线为1的对称矩阵
# 把中心距离小于阈值的中心对找出来
while True:
# 如果最小的中心距离都大于阈值的话,则不再进行合并
minDist = np.min(centerDist)
if minDist >= self.thetaC:
break
# 否则合并(此部分是找到中心点间距离最小的一对中心点。该对是label值为row值和col值的两个中心)
index = np.argmin(centerDist) #将数据展成一行,找到最小值的列数(列数从0开始数)
row = index // self.centershape #符号//表示取除法结果的整数部分
col = index % self.centershape #符号%表示取余数
# 找出合并的两个类别
index = np.argwhere(self.classes == row).squeeze()
classNumFirst = len(index)
index = np.argwhere(self.classes == col).squeeze()
classNumSecond = len(index)
a=1
while classNumFirst == 0 or classNumSecond == 0:
delIndexList.append(row)
delIndexList.append(col)
self.centerNum -= 1
self.center = np.delete(self.center, [row,col], axis=0)
centerDist[row, :] = float("inf")
centerDist[col, :] = float("inf")
centerDist[:, col] = float("inf")
centerDist[:, row] = float("inf")
break
else:
a=a+1
newCenter = self.center[row, :] * (classNumFirst / (classNumFirst+ classNumSecond)) + \
self.center[col, :] * (classNumSecond / (classNumFirst+ classNumSecond))
# 记录被合并的中心
delIndexList.append(row)
delIndexList.append(col)
# 增加合并后的中心
self.center = np.vstack((self.center, newCenter))
self.centerNum -= 1
# 标记,以防下次选中
centerDist[row, :] = float("inf")
centerDist[col, :] = float("inf")
centerDist[:, col] = float("inf")
centerDist[:, row] = float("inf")
# 更新中心
self.center = np.delete(self.center, delIndexList, axis=0)
self.centerNum = self.center.shape[0] #shape[0]为矩阵的行数
def showResult(self):
for j in self.centerNumList:
indexs = np.where(self.classes == j)
indexs= numpy.array(indexs)
x_value = indexs[0,:]
y_value = indexs[1,:]
plt.scatter(x_value, y_value, label="A产品")
plt.show()
def train(self):
# 初始化中心和label
self.updateLabel()
# self.drawResult()
self.showResult()
# 到设定的次数自动退出
for i in range(self.iteration):
# 如果是偶数次迭代或者中心的数量太多,那么进行合并
if self.centerNum < self.K //2:
self.divide()
elif (i > 0 and i % 2 == 0) or self.centerNum > 2 * self.K:
self.combine()
else:
self.divide()
# 更新中心
self.updateLabel()
# self.drawResult()
self.showResult()
print("中心数量:{}".format(self.centerNum))
if __name__ == "__main__":
isoData = ISODATA(designCenterNum=5, LeastSampleNum=10, StdThred=0.01, LeastCenterDist=2, iterationNum=20)#调用ISODATA类(用户指定预期的聚类数、每类的最小样本数、标准差阈值、最小中心距离、迭代次数)
isoData.train()#运行完第一步后进行该步骤
这个算法需要科研工作者对自身的数据有很多的了解,因为需要用户指定:每类的最小样本数、标准差阈值、最小中心距离、一个系数t。
系数t决定了每次分裂后新产生的center是否合理,如果不合理会导致永远无法找到正确的新center,会进行合并、分裂、合并…的死循环,而分类结果没有任何提升
每类的最小样本数用户应该心里大致有数
标准差阈值用户也应该心里大致有数
最小中心距离这个很麻烦,但是如果参数设置偏差太多也会很麻烦。
总之,这让我对算法是什么产生了新的认识,这项工作很有意义!