python最简单的聚类分析——遥感图像分类——最小距离分类

模式识别实验报告

  • 实验目的

目的:遥感图像分类。

  1. 利用envi专业遥感图像处理软件对遥感图像进行最小距离分类***************envi4.7经典版链接:https://pan.baidu.com/s/1bUx7Sym9jyd7ZpSSiu2QVw    提取码:h0pq
  2. 基于python编程实现对遥感图像的最小距离分类
  • 实验原理

最小距离分类法是分类器里面最基本的一种分类方法,它是通过求出未知类别向量X到事先已知的各类别(如A,B,C等等)中心向量的距离D,然后将待分类的向量X归结为这些距离中最小的那一类的分类方法

最小距离分类的原理:

在一个n维空间中,最小距离分类法首先计算每一个已知类别X(用向量表示是的各个维度的均值,形成形成一个均值 ,用向量表示A为类别的名称, 是类别A的样本特征集合是类别A的第1维特征集合,是第一维特征集合的均值,n为总的特征维数),同理,计算另一个类别(用向量表示是)的均值用向量表示,那么对于一个待分类的样本特征向量(用向量表示是),怎么判断它是属于类别还是呢?我们只需要分别计算的距离,以欧式距离为例,距离的计算公式如下:

然后找中的最小值,如果前者最小,那么X属于A类,如果后者小,那么X属于B类。

上面只是分了2类,在我下面的实验中分了6类,当然可以分成更多类。

 

  • 实验方案

(一)ENVI实现

1、File→Open image file→选择3个波段,对应RGB波段,对应下图。

python最简单的聚类分析——遥感图像分类——最小距离分类_第1张图片

 

2、Basic Tools→Region Of Interest→ROI Tool,选择测试样本和训练样本

python最简单的聚类分析——遥感图像分类——最小距离分类_第2张图片

 

3、ClassificationSupervisedMinimum Distance

python最简单的聚类分析——遥感图像分类——最小距离分类_第3张图片python最简单的聚类分析——遥感图像分类——最小距离分类_第4张图片

                                 python最简单的聚类分析——遥感图像分类——最小距离分类_第5张图片

                             python最简单的聚类分析——遥感图像分类——最小距离分类_第6张图片

 

 

4、计算混淆矩阵,精度评估,Classification→Post Classification→Confusion Matrix

python最简单的聚类分析——遥感图像分类——最小距离分类_第7张图片python最简单的聚类分析——遥感图像分类——最小距离分类_第8张图片

python最简单的聚类分析——遥感图像分类——最小距离分类_第9张图片

 

 

(二)Python编程实现

下面我们来谈谈最小距离分类法的一般步骤,说是最小距离分类器的步骤,其实是我们做监督分类基本的几个步骤。

1、确定类别m,并提取每一类所对应的已知的样本

2、从样本中提取出一些可以作为区分不同类别的特性,也就是我们通常所说的特征提取,如果提取出了n个不同的特性,那么我们就叫它n维空间,特征提取对分类的精度有重大的影响

3、分别计算每一个类别的样本所对应的特征,每一类的每一维都有特征集合,通过集合,可以计算出一个均值,也就是特征中心。

4、通常为了消除不同特征因为量纲不同的影响,我们对每一维的特征,需要做一个归一化,或者是放缩到(-1,1)等区间,使其去量纲化

5、利用选取的距离准则,对待分类的本进行判定。

 

Python代码:

说明:

  1. 运行环境:python3.7;
  2. 主要用gdal、numpy模块。
import gdal
import numpy as np
import os

class Dataset:
    def __init__(self, in_file):
        self.in_file = in_file  # Tiff或者ENVI文件
        dataset = gdal.Open(self.in_file)
        self.XSize = dataset.RasterXSize  # 网格的X轴像素数量307
        self.YSize = dataset.RasterYSize  # 网格的Y轴像素数量250
        self.GeoTransform = dataset.GetGeoTransform()  # 投影转换信息
        self.ProjectionInfo = dataset.GetProjection()  # 投影信息

    def get_data(self):
        dataset = gdal.Open(self.in_file)
        # BBand = dataset.GetRasterBand(band)
        im_geotrans = dataset.GetGeoTransform()  # 仿射矩阵
        im_proj = dataset.GetProjection()  # 地图投影信息
        self.data = dataset.ReadAsArray(0, 0, self.XSize, self.YSize).astype(np.float32)
        data=self.data
        return im_geotrans,im_proj,data

    def get_local(self,classcount):
        X=[[] for j in range(classcount)]
        i=1
        # print(type(self.data[0][0]),type(i))
        while i <= classcount:
            for y in range(0, self.YSize):
                for x in range(0, self.XSize):
                        if int(self.data[y][x])==i:
                            X[i-1].append([y,x])
            i=i+1
        return X
#######################################################################
    def get_duiying_local(self,data,XX):#找到对应像素出的波值求均值
        avg=0
        sixclass=[]
        for dat in data:
            for Xx in XX:
                y=Xx[0]
                x=Xx[1]
                avg = avg + dat[y][x]
            avg_sum=avg/len(XX)
            sixclass.append(avg_sum)
            avg=0.0
        return sixclass
#######################################################################
#######################################################################
    def cal_box(self,Xband6class,data):  # 计算矩阵求最小值
        piexl=[]
        pic=[[0 for i in range(self.XSize)] for j in range(self.YSize)]
        mat_list_class6 = np.array(Xband6class)
        for y in range(0, self.YSize):
            for x in range(0, self.XSize):
                for dat in data:
                    piexl.append(dat[y][x])
                piexl80=np.array(piexl)
                cha=piexl80-mat_list_class6
                mat_list_class6_tranpose=np.transpose(cha)
                dot_result=np.dot(mat_list_class6_tranpose,cha)
                # print(dot_result)
                pic[y][x]=dot_result
                piexl=[]
        return pic

    def select_min(self,pic_class):
        mins=[[0 for i in range(self.XSize)] for j in range(self.YSize)]
        six_6_class_list=[]
        for y in range(self.YSize):
            for x in range(self.XSize):
                for pic_class_single in pic_class:
                    six_6_class=pic_class_single[y][x]
                    six_6_class_list.append(six_6_class)
                min_class=min(six_6_class_list)
                # print(min_class)
                num_class=six_6_class_list.index(min_class)
                num_class=num_class+1
                mins[y][x]=num_class
                six_6_class_list=[]
        return mins
############################################################
    def writeimage(self,pic_rgb):
        # print(pic_rgb)
        pic=[[[0 for i in range(self.XSize)] for j in range(self.YSize)] for k in range(3)]
        #定义类别的颜色,我分为6类所以定义6种颜色,也可以定义为其它颜色,根据分类数定义类别#颜色的数量
        pic_rgb_red=[255,0,0]
        pic_rgb_green=[0,255,0]
        pic_rgb_blue=[0,0,255]
        pic_rgb_rg=[255,255,0]
        pic_rgb_rb=[255,0,255]
        pic_rgb_gb=[0,255,255]

        for j in range(self.YSize):
            for i in range(self.XSize):
                if pic_rgb[j][i]==1:
                    pic[0][j][i]=pic_rgb_red[0]
                    pic[1][j][i]=pic_rgb_red[1]
                    pic[2][j][i]=pic_rgb_red[2]
                if pic_rgb[j][i]==2:
                    pic[0][j][i] = pic_rgb_green[0]
                    pic[1][j][i] = pic_rgb_green[1]
                    pic[2][j][i] = pic_rgb_green[2]
                if pic_rgb[j][i]==3:
                    pic[0][j][i] = pic_rgb_blue[0]
                    pic[1][j][i] = pic_rgb_blue[1]
                    pic[2][j][i] = pic_rgb_blue[2]
                if pic_rgb[j][i]==4:
                    pic[0][j][i] = pic_rgb_rg[0]
                    pic[1][j][i] = pic_rgb_rg[1]
                    pic[2][j][i] = pic_rgb_rg[2]
                if pic_rgb[j][i]==5:
                    pic[0][j][i] = pic_rgb_rb[0]
                    pic[1][j][i] = pic_rgb_rb[1]
                    pic[2][j][i] = pic_rgb_rb[2]
                if pic_rgb[j][i]==6:
                    pic[0][j][i] = pic_rgb_gb[0]
                    pic[1][j][i] = pic_rgb_gb[1]
                    pic[2][j][i] = pic_rgb_gb[2]
        return pic

    def write_img(self,filename,im_proj,im_geotrans,im_data):

        #gdal数据类型包括
        #gdal.GDT_Byte,
        #gdal .GDT_UInt16, gdal.GDT_Int16, gdal.GDT_UInt32, gdal.GDT_Int32,
        #gdal.GDT_Float32, gdal.GDT_Float64
        #判断栅格数据的数据类型
        if 'int8' in im_data.dtype.name:
            datatype = gdal.GDT_Byte
        elif 'int16' in im_data.dtype.name:
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32
        #判读数组维数
        if len(im_data.shape) == 3:
            im_bands, im_height, im_width = im_data.shape
        else:
            im_bands, (im_height, im_width) = 1,im_data.shape

        #创建文件
        driver = gdal.GetDriverByName("ENVI")            #数据类型必须有,因为要计算需要多大内存空间或者GTiff
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

        dataset.SetGeoTransform(im_geotrans)              #写入仿射变换参数
        dataset.SetProjection(im_proj)                    #写入投影

        if im_bands == 1:
            dataset.GetRasterBand(1).WriteArray(im_data)  #写入数组数据
        else:
            for i in range(im_bands):
                # print(im_data[i])
                dataset.GetRasterBand(i+1).WriteArray(im_data[i])
        del dataset
#####################################################################

#样本信息处理
def main1():
    dir_path = r"D:\Python_DATA\game project\模式识别\images"
    filename = "xuanlian.img"
    file_path = os.path.join(dir_path, filename)
    dataset = Dataset(file_path)
    dataset.get_data()
    classcount=6 #分为6类,当然也可以分为更多类,这个参数可以根据选择的训练样本的类数改变
    return dataset.get_local(classcount)

def main2(X):
    dir_path = r"D:\Python_DATA\game project\模式识别\images"
    filename = "PHI.tif"
    file_path = os.path.join(dir_path, filename)
    dataset = Dataset(file_path)
    geotrans,proj,data=dataset.get_data()
    pic_class=[]
    for XX in X:
        Xband6class = dataset.get_duiying_local(data, XX)
        pics=dataset.cal_box(Xband6class,data)
        pic_class.append(pics)
    mins=dataset.select_min(pic_class)
    impic=dataset.writeimage(mins)
    impic=np.array(impic)
    print(impic.shape)
    dataset.write_img("re_img_changzhou.img",proj,geotrans,impic)

if __name__=="__main__":
    X=main1()
    main2(X)

 

python最简单的聚类分析——遥感图像分类——最小距离分类_第10张图片

输出结果re_img_changzhou.img

 

  • 结果分析

 

(envi分类后结果)

python最简单的聚类分析——遥感图像分类——最小距离分类_第11张图片

 

(编程实现分类后结果)

python最简单的聚类分析——遥感图像分类——最小距离分类_第12张图片

由上两张图可以看出,编程实现的结果和envi分类的结果都是比较理想的。精度在85%以上,所以分类的效果是比较明显的。

 

  • 总结

最小距离分类法原理简单,容易理解,计算速度快,但是因为其只考虑每一类样本的均值,而不用管类别内部的方差(每一类样本的分布),也不用考虑类别之间的协方差(类别和类别之间的相关关系),所以分类精度不高,因此,一般不用它作为我们分类对精度有高要求的分类,但它可以在快速浏览分类概况中使用。

最小距离分类算法是比较简单的好理解的。但是在编程实现的过程中出现的问题是值得铭记的,以防在下一次出现类似的错误。

 

 

 

 

 

 

你可能感兴趣的:(不是爬虫,python聚类分析,最小距离分类,gdal读取遥感图像)