监督分类:SVM即支持向量机实现遥感影像监督分类(更新:添加机器学习模型存储、大影像划框拼接)

前面已经有一个版本了,但是影像太大内存顶不住,而且训练和预测没有分离,后面批量用这个不可能每次每张影像都训练了再预测,这次正好有需求,我就最后把这个整理一下,算是终版吧,以后也不会再花时间整这个了
这里放个植被提取的结果,仔细看应该比NDVI要好很多
结果:

监督分类:SVM即支持向量机实现遥感影像监督分类(更新:添加机器学习模型存储、大影像划框拼接)_第1张图片

# -*- coding: utf-8 -*-
from osgeo import ogr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import numpy as np
from sklearn.svm import SVC
import pickle

#获取样本点及其领域的值
def getPixels(shp, img):
    driver = ogr.GetDriverByName('ESRI Shapefile')
    ds = driver.Open(shp, 0)
    if ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = ds.GetLayer()
    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        x = geometry.GetX()
        y = geometry.GetY()
        xValues.append(x)
        yValues.append(y)
        feature = layer.GetNextFeature()

    gdal.AllRegister()
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    bands = ds.RasterCount
    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]
    values = []
    for i in range(len(xValues)):
        x = xValues[i]
        y = yValues[i]
        xOffset = int((x - xOrigin) / pixelWidth)
        yOffset = int((y - yOrigin) / pixelHeight)
        s = str(int(x)) + ' ' + str(int(y)) + ' ' + str(xOffset) + ' ' + str(yOffset) + ' '
        pt = []
        for j in range(bands):
            band = ds.GetRasterBand(j + 1)
            data = band.ReadAsArray(xOffset-5, yOffset-5, 10, 10)#请务必注意这里,这是以点为中心去了10*10的影像块作为样本,这可以作为你调整效果的一个参数吧
            value = data
            value = value.flatten()
            pt.append(value)
        
        temp = []
        pt = array_change(pt, temp)
        values.append(pt)
    
    temp2 = []
    all_values = array_change(values, temp2)
    all_values = np.asarray(all_values)
    temp3 = []
    result_values = array_change2(all_values, temp3)
    result_values = np.asarray(result_values)
    return result_values

#这是SVM的训练部分,这里面最后把模型存储在pickle文件里了
def svm_train(classArray, img_arr, model_path):
    array_num = len(classArray)
    classArray = np.asarray(classArray)
    RGB_arr = classArray[0]
    for k in range(array_num-1):
        RGB_arr = np.concatenate((RGB_arr,classArray[k+1]),axis=0)
    
    label= np.array([])
    for h in range(array_num):
        array_l = classArray[h].shape[0]
        label = np.append(label,h*np.ones(array_l))
    
    if os.path.exists(model_path):
        pass
    else:
        svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
        svc.fit(RGB_arr,label)
        with open(model_path, 'wb') as f:
            pickle.dump(svc, f)
    
    return array_num

#SVM加载模型
def get_model(model_path):
    with open(model_path, 'rb') as f:
        svc = pickle.load(f)
    return svc

#SVM预测函数
def svm_predict(svc, img_arr, array_num, outPath): 
    img_reshape = img_arr.reshape([img_arr.shape[0]*img_arr.shape[1],img_arr.shape[2]])
    predict = svc.predict(img_reshape)  #类别结果
    # prob = svc.predict_proba(img_reshape) #概率结果,可以选择输出
    for j in range(array_num):
        lake_bool = predict == np.float(j)
        lake_bool = lake_bool[:,np.newaxis]
        lake_bool_4col = np.concatenate((lake_bool,lake_bool,lake_bool,lake_bool),axis=1)
        # lake_bool_4col = np.concatenate((lake_bool,lake_bool,lake_bool),axis=1)#注意一下这里,你影像要是4波段就是打开上面那句,3个波段打开就下面这句
        lake_bool_4d = lake_bool_4col.reshape((img_arr.shape[0],img_arr.shape[1],img_arr.shape[2]))
        img_arr[lake_bool_4d] = np.float(j)

    img_arr = img_arr.transpose((2,1,0))
    img_arr = img_arr[0]
    return img_arr

#读遥感影像获取基本信息,后面还有一个,可以删除一个,但是代码得稍微改下,我懒就不改了,你们自己改吧
def read_img(filename):
    dataset=gdal.Open(filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0,0,im_width,im_height)
    del dataset 
    return im_proj,im_geotrans,im_width, im_height,im_data

#写影像文件
def write_img(filename, im_proj, im_geotrans, im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset

def array_change(inlist, outlist):
    for i in range(len(inlist[0])):
        outlist.append([j[i] for j in inlist])
    return outlist

def array_change2(inlist, outlist):
    for ele in inlist:
        for ele2 in ele:
            outlist.append(ele2)
    return outlist

#获取影像基本信息
def getTifSize(tif):
    dataSet = gdal.Open(tif)
    width = dataSet.RasterXSize
    height = dataSet.RasterYSize
    bands = dataSet.RasterCount
    geoTrans = dataSet.GetGeoTransform()
    proj = dataSet.GetProjection()
    return width,height,bands,geoTrans,proj

#划框预测影像块,model是模型全路径,array_num是样本数据,获取方式看我调用的地方,tif1大影像路径,divisionSize划框的快大小,tempPath预测影像块的临时存储路径
def partDivisionForBoundary(model,array_num,tif1,divisionSize,tempPath):
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    partWidth = partHeight = divisionSize

    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum =  width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight

    realName = os.path.split(tif1)[1].split(".")[0]
    tif1 = gdal.Open(tif1)
    for i in range(heightNum):
        for j in range(widthNum):
            startX = partWidth * j
            startY = partHeight * i
            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width  and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY

            outName = realName + '_i' +str(i)+ '_j' +str(j)+".tif"
            outPath = os.path.join(tempPath,outName)
            if not os.path.exists(outPath):
                driver = gdal.GetDriverByName("GTiff")
                outTif = driver.Create(outPath,realPartWidth,realPartHeight,1,gdal.GDT_Float32)#注意预测结果是单通道的,所以要写1
                outTif.SetGeoTransform(geoTrans)
                outTif.SetProjection(proj)
                data1 = tif1.ReadAsArray(startX,startY,realPartWidth,realPartHeight)
                data1 = data1.transpose((2,1,0))
                svmData = svm_predict(model,data1,array_num,outPath)
                outTif.GetRasterBand(1).WriteArray(svmData)

#影像拼接
def partStretch(tif1,divisionSize,outStratchPath,tempPath):
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    # bands = 1
    partWidth = partHeight = divisionSize

    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum =  width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight

    realName = os.path.split(tif1)[1].split(".")[0]
    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath,width,height,1,gdal.GDT_Byte)#注意预测结果是单通道的,所以要写1
    if outTif!= None:
        outTif.SetGeoTransform(geoTrans)
        outTif.SetProjection(proj)
    for i in range(heightNum):
        for j in range(widthNum):

            startX = partWidth * j
            startY = partHeight * i

            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width  and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY

            partTifName = realName + '_i' +str(i)+ '_j' +str(j)+".tif"
            partTifPath = os.path.join(tempPath,partTifName)
            divisionImg = gdal.Open(partTifPath)
            for k in range(1):  #因为结果是单波段,这里不想删改了,以后用到多波段正好
                data1 = divisionImg.GetRasterBand(k+1).ReadAsArray(0,0,realPartWidth,realPartHeight)
                outPartBand = outTif.GetRasterBand(k+1)
                outPartBand.WriteArray(data1,startX,startY)


if __name__ == '__main__':
    img_p = 'E:/wangjiangxian_2017.tif' #大影像的位置
    shp_path = 'E:/change/wj/point/' #样本点位置,这个样本点文件夹是0.shp,1.shp...这样的矢量文件,对应每个类别,和预测结果的0,1...一一对应,不要乱写文件名
    temp_path = 'E:/temp2017/'  #临时文件存放路径
    model_path = 'E:/model/model2017.pickle' #模型路径
    re_path = "E:/wangjiangxian_2017_s.tif" #结果路径

    time1 = time.time()
    class_list = []
    for shp in os.listdir(shp_path):
        if shp[-4:] == '.shp':
            shp_full_path = os.path.join(shp_path, shp)
            class_type  = getPixels(shp_full_path, img_p)
            class_list.append(class_type)
            
    num = svm_train(class_list,img_p,model_path)
    svm = get_model(model_path)
    partDivisionForBoundary(svm,num,img_p,5000,temp_path)
    partStretch(img_p,5000,re_path,temp_path)
    time2 = time.time()
    print((time2-time1)/3600)

注意一下:使用的时候我们发现SVM这个算法的速度和影像大小呈非线性的关系,影像越小速度越快,虽然不是以几何级数的量增长但也有几十倍了,所以滑块不要太大,这样会很快,也不要太小,后面还有拼接

追加一部分机器学习模块SVM相关的资料(我搜的),一起了解一下:

SVC
SVC用于分类:支持向量分类,基于libsvm实现的,数据拟合的时间复杂度是数据样本的二次方,这使得他很难扩展到10000个数据集,当输入是多类别时(SVM最初是处理二分类问题的),通过一对一的方案解决,当然也有别的解决办法。
SVC参数说明如下:
C:惩罚项,float类型,可选参数,默认为1.0,C越大,即对分错样本的惩罚程度越大,因此在训练样本中准确率越高,但是泛化能力降低,也就是对测试数据的分类准确率降低。相反,减小C的话,容许训练样本中有一些误分类错误样本,泛化能力强。对于训练样本带有噪声的情况,一般采用后者,把训练样本集中错误分类的样本作为噪声。
kernel:核函数类型,str类型,默认为’rbf’。可选参数为:
• ‘linear’:线性核函数
• ‘poly’:多项式核函数
• ‘rbf’:径像核函数/高斯核
• ‘sigmod’:sigmod核函数
• ‘precomputed’:核矩阵。precomputed表示自己提前计算好核函数矩阵,这时候算法内部就不再用核函数去计算核矩阵,而是直接用你给的核矩阵,核矩阵需要为n*n的。
degree:多项式核函数的阶数,int类型,可选参数,默认为3。这个参数只对多项式核函数有用,是指多项式核函数的阶数n,如果给的核函数参数是其他核函数,则会自动忽略该参数。
gamma:核函数系数,float类型,可选参数,默认为auto。只对’rbf’ ,’poly’ ,’sigmod’有效。如果gamma为auto,代表其值为样本特征数的倒数,即1/n_features。
coef0:核函数中的独立项,float类型,可选参数,默认为0.0。只有对’poly’ 和,’sigmod’核函数有用,是指其中的参数c。
probability:是否启用概率估计,bool类型,可选参数,默认为False,这必须在调用fit()之前启用,并且会fit()方法速度变慢。
shrinking:是否采用启发式收缩方式,bool类型,可选参数,默认为True。
tol:svm停止训练的误差精度,float类型,可选参数,默认为1e^-3。
cache_size:内存大小,float类型,可选参数,默认为200。指定训练所需要的内存,以MB为单位,默认为200MB。
class_weight:类别权重,dict类型或str类型,可选参数,默认为None。给每个类别分别设置不同的惩罚参数C,如果没有给,则会给所有类别都给C=1,即前面参数指出的参数C。如果给定参数’balance’,则使用y的值自动调整与输入数据中的类频率成反比的权重。
verbose:是否启用详细输出,bool类型,默认为False,此设置利用libsvm中的每个进程运行时设置,如果启用,可能无法在多线程上下文中正常工作。一般情况都设为False,不用管它。
max_iter:最大迭代次数,int类型,默认为-1,表示不限制。
decision_function_shape:决策函数类型,可选参数’ovo’和’ovr’,默认为’ovr’。’ovo’表示one vs one,’ovr’表示one vs rest。
random_state:数据洗牌时的种子值,int类型,可选参数,默认为None。伪随机数发生器的种子,在混洗数据时用于概率估计。
NuSVC
NuSVC(Nu-Support Vector Classification.):核支持向量分类,和SVC类似,也是基于libsvm实现的,但不同的是通过一个参数空值支持向量的个数。
• nu:训练误差的一个上界和支持向量的分数的下界。应在间隔(0,1 ]。
• 其余同SVC
LinearSVC
LinearSVC(Linear Support Vector Classification):线性支持向量分类,类似于SVC,但是其使用的核函数是”linear“上边介绍的两种是按照brf(径向基函数计算的,其实现也不是基于LIBSVM,所以它具有更大的灵活性在选择处罚和损失函数时,而且可以适应更大的数据集,他支持密集和稀疏的输入是通过一对一的方式解决的。
LinearSVC 参数解释
C:目标函数的惩罚系数C,用来平衡分类间隔margin和错分样本的,default C = 1.0;
loss:指定损失函数
penalty :
dual :选择算法来解决对偶或原始优化问题。当nsamples>nfeaturesnsamples>nfeatures 时dual=false。
tol :(default = 1e - 3): svm结束标准的精度;
multi_class:如果y输出类别包含多类,用来确定多类策略, ovr表示一对多,“crammer_singer”优化所有类别的一个共同的目标 。如果选择“crammer_singer”,损失、惩罚和优化将会被被忽略。
fit_intercept :
intercept_scaling :
class_weight :对于每一个类别i设置惩罚系数C=classweight[i]∗CC=classweight[i]∗C,如果不给出,权重自动调整为 nsamples/(nclasses∗np.bincount(y))nsamples/(nclasses∗np.bincount(y))
verbose:跟多线程有关.

使用诀窍
• 避免数据复制: 对于 SVC, SVR, NuSVC 和 NuSVR, 如果数据是通过某些方法而不是用 C 有序的连续双精度,那它先会调用底层的 C 命令再复制。 您可以通过检查它的 flags 属性,来确定给定的 numpy 数组是不是 C 连续的。
• 对于 LinearSVC (和 LogisticRegression) 的任何输入,都会以 numpy 数组形式,被复制和转换为 用 liblinear 内部稀疏数据去表达(双精度浮点型 float 和非零部分的 int32 索引)。 如果您想要一个适合大规模的线性分类器,又不打算复制一个密集的 C-contiguous 双精度 numpy 数组作为输入, 那我们建议您去使用 SGDClassifier 类作为替代。目标函数可以配置为和 LinearSVC 模型差不多相同的。
• 内核的缓存大小: 在大规模问题上,对于 SVC, SVR, nuSVC 和 NuSVR, 内核缓存的大小会特别影响到运行时间。如果您有足够可用的 RAM,不妨把它的 缓存大小 设得比默认的 200(MB) 要高,例如为 500(MB) 或者 1000(MB)。
• 惩罚系数C的设置:在合理的情况下, C 的默认选择为 1 。如果您有很多混杂的观察数据, 您应该要去调小它。 C 越小,就能更好地去正规化估计。
• 支持向量机算法本身不是用来扩大不变性,所以 我们强烈建议您去扩大数据量. 举个例子,对于输入向量 X, 规整它的每个数值范围为 [0, 1] 或 [-1, +1] ,或者标准化它的为均值为0方差为1的数据分布。请注意, 相同的缩放标准必须要应用到所有的测试向量,从而获得有意义的结果。 请参考章节 预处理数据 ,那里会提供到更多关于缩放和规整。
• 在 NuSVC/OneClassSVM/NuSVR 内的参数 nu , 近似是训练误差和支持向量的比值。
• 在 SVC, ,如果分类器的数据不均衡(就是说,很多正例很少负例),设置 class_weight=’balanced’ 与/或尝试不同的惩罚系数 C 。
• 在拟合模型时,底层 LinearSVC 操作使用了随机数生成器去选择特征。 所以不要感到意外,对于相同的数据输入,也会略有不同的输出结果。如果这个发生了, 尝试用更小的 tol 参数。
• 使用由 LinearSVC(loss=’l2’, penalty=’l1’, dual=False) 提供的 L1 惩罚去产生稀疏解,也就是说,特征权重的子集不同于零,这样做有助于决策函数。 随着增加 C 会产生一个更复杂的模型(要做更多的特征选择)。可以使用 l1_min_c 去计算 C 的数值,去产生一个”null” 模型(所有的权重等于零)。

SVM算法最初是为二值分类问题设计的,当处理多类问题时,就需要构造合适的多类分类器。目前,构造SVM多类分类器的方法主要有两类:一类是直接法,直接在目标函数上进行修改,将多个分类面的参数求解合并到一个最优化问题中,通过求解该最优化问题“一次性”实现多类分类。这种方法看似简单,但其计算复杂度比较高,实现起来比较困难,只适合用于小型问题中;另一类是间接法,主要是通过组合多个二分类器来实现多分类器的构造,常见的方法有one-against-one和one-against-all两种。

a.一对多法(one-versus-rest,简称1-v-r SVMs)。训练时依次把某个类别的样本归为一类,其他剩余的样本归为另一类,这样k个类别的样本就构造出了k个SVM。分类时将未知样本分类为具有最大分类函数值的那类。

b.一对一法(one-versus-one,简称1-v-1 SVMs)。其做法是在任意两类样本之间设计一个SVM,因此k个类别的样本就需要设计k(k-1)/2个SVM。当对一个未知样本进行分类时,最后得票最多的类别即为该未知样本的类别。Libsvm中的多类分类就是根据这个方法实现的。

c.层次支持向量机(H-SVMs)。层次分类法首先将所有类别分成两个子类,再将子类进一步划分成两个次级子类,如此循环,直到得到一个单独的类别为止。

d.其他多类分类方法。除了以上几种方法外,还有有向无环图SVM(Directed Acyclic Graph SVMs,简称DAG-SVMs)和对类别进行二进制编码的纠错编码SVMs。

对c和d两种方法的详细说明可以参考论文《支持向量机在多类分类问题中的推广》

decision_function: 返回的是样本距离超平面的距离。二分类没什么好说的,对于多分类ovo,得到每对分类器的输出,n_class *(n_class - 1)/ 2个值。举个列子,

clf.decision_function(predict_this)
[[ 96.42193513 -11.13296606 111.47424538 -88.5356536 44.29272494 141.0069203 ]
对应的分类器是 [AB, AC, AD, BC, BD, CD]
所以我们得到每对分类器的结果[A, C, A, C, B, C]
例如,96.42193513 是正的,所以AB分离器得到的label是A
因为[A, C, A, C, B, C]中有3个C,得票最多,所以C就是整个多分类模型的预测label,这个就是使用predict得到的结果
而ovr,直接选择绝对值最大那个作为预测label
predict_proba: predict_proba涉及到Platt scaling,SVM中Platt scaling涉及到某些理论问题,如果一定要使用一个得分去表示,可以使用decision_function 去代替predict_proba

多核优化版本:增加了svm_train2和svm_predict2
参考链接:https://blog.csdn.net/hohaizx/article/details/79656496
你发现速度变慢了,其实是因为使用了parameters 参数,里面列举了很多种需要训练的情况,训练了很多次模型的原因,可以自己减少。另外注意这个n_jobs=12,n_jobs要比电脑核数要少才行

# -*- coding: utf-8 -*-
from osgeo import ogr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import numpy as np
from sklearn.svm import SVC
import pickle
from sklearn.model_selection import GridSearchCV

def getPixels(shp, img):
    driver = ogr.GetDriverByName('ESRI Shapefile')
    ds = driver.Open(shp, 0)
    if ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = ds.GetLayer()

    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        x = geometry.GetX()
        y = geometry.GetY()
        xValues.append(x)
        yValues.append(y)
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    bands = ds.RasterCount

    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]

    values = []
    for i in range(len(xValues)):
        x = xValues[i]
        y = yValues[i]

        xOffset = int((x - xOrigin) / pixelWidth)
        yOffset = int((y - yOrigin) / pixelHeight)

        s = str(int(x)) + ' ' + str(int(y)) + ' ' + str(xOffset) + ' ' + str(yOffset) + ' '

        pt = []
        for j in range(bands):
            band = ds.GetRasterBand(j + 1)
            data = band.ReadAsArray(xOffset-5, yOffset-5, 10, 10)
            value = data
            value = value.flatten()
            pt.append(value)
        
        temp = []
        pt = array_change(pt, temp)
        values.append(pt)
    
    temp2 = []
    all_values = array_change(values, temp2)
    all_values = np.asarray(all_values)

    temp3 = []
    result_values = array_change2(all_values, temp3)
    result_values = np.asarray(result_values)
    return result_values


def svmDeal(classArray, img_arr, outPath, im_proj, im_geotrans):
    array_num = len(classArray)
    classArray = np.asarray(classArray)
    # array_l = classArray[0].shape[0]

    RGB_arr = classArray[0]
    for k in range(array_num-1):
        RGB_arr = np.concatenate((RGB_arr,classArray[k+1]),axis=0)
    
    label= np.array([])
    for h in range(array_num):
        array_l = classArray[h].shape[0]
        label = np.append(label,h*np.ones(array_l))

    img_reshape = img_arr.reshape([img_arr.shape[0]*img_arr.shape[1],img_arr.shape[2]])
    # svc = SVC(kernel='poly', degree=4, cache_size=1000, max_iter=100)
    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    # svc = SVC(C=0.8, kernel='poly', cache_size=1000)
    svc.fit(RGB_arr,label)
    predict = svc.predict(img_reshape)
    for j in range(array_num):
        lake_bool = predict == np.float(j)
        lake_bool = lake_bool[:,np.newaxis]
        # lake_bool_4col = np.concatenate((lake_bool,lake_bool,lake_bool,lake_bool),axis=1)
        lake_bool_4col = np.concatenate((lake_bool,lake_bool,lake_bool),axis=1)
        lake_bool_4d = lake_bool_4col.reshape((img_arr.shape[0],img_arr.shape[1],img_arr.shape[2]))
        img_arr[lake_bool_4d] = np.float(j)

    img_arr = img_arr.transpose((2,1,0))
    img_arr = img_arr[0]
    write_img(outPath, im_proj, im_geotrans, img_arr)

def svm_train(classArray, img_arr, model_path):
    array_num = len(classArray)
    classArray = np.asarray(classArray)

    RGB_arr = classArray[0]
    for k in range(array_num-1):
        RGB_arr = np.concatenate((RGB_arr,classArray[k+1]),axis=0)
    
    label= np.array([])
    for h in range(array_num):
        array_l = classArray[h].shape[0]
        label = np.append(label,h*np.ones(array_l))
        
    if os.path.exists(model_path):
        pass
    else:
        # svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
        svc = SVC(C=0.8, kernel='linear', cache_size=1000)
        svc.fit(RGB_arr,label)
        with open(model_path, 'wb') as f:
            pickle.dump(svc, f)
    
    return array_num

def svm_train2(classArray, img_arr, model_path):
    array_num = len(classArray)
    classArray = np.asarray(classArray)

    RGB_arr = classArray[0]
    for k in range(array_num-1):
        RGB_arr = np.concatenate((RGB_arr,classArray[k+1]),axis=0)
    
    label= np.array([])
    for h in range(array_num):
        array_l = classArray[h].shape[0]
        label = np.append(label,h*np.ones(array_l))
    
    svc = SVC()
    parameters = [
        # {
        #     'C': [0.8, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
        #     'gamma': [0.00001, 0.0001, 0.001, 0.1, 1, 10, 100, 1000],
        #     'kernel': ['rbf']
        # },
        # {
        #     'C': [0.8, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
        #     'kernel': ['linear']
        # }
        {
            'C': [0.8, 1, 3],
            'gamma': [0.00001, 0.0001, 0.001],
            'kernel': ['rbf']
        },
        {
            'C': [0.8, 1, 3, 5],
            'kernel': ['linear']
        }
    ]

    if os.path.exists(model_path):
        pass
    else:
        # svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
        # svc = SVC(C=0.8, kernel='linear', cache_size=1000)
        clf = GridSearchCV(svc, parameters, cv=5, n_jobs=12)
        clf.fit(RGB_arr,label)
        print(clf.best_params_)
        with open(model_path, 'wb') as f:
            pickle.dump(clf, f)
    
    return array_num

def get_model(model_path):
    with open(model_path, 'rb') as f:
        clf = pickle.load(f)
    return clf

def svm_predict(svc, img_arr, array_num, outPath): 
    img_reshape = img_arr.reshape([img_arr.shape[0]*img_arr.shape[1],img_arr.shape[2]])
    predict = svc.predict(img_reshape)
    for j in range(array_num):
        lake_bool = predict == np.float(j)
        lake_bool = lake_bool[:,np.newaxis]
        # lake_bool_4col = np.concatenate((lake_bool,lake_bool,lake_bool,lake_bool),axis=1)
        lake_bool_4col = np.concatenate((lake_bool,lake_bool,lake_bool),axis=1)
        lake_bool_4d = lake_bool_4col.reshape((img_arr.shape[0],img_arr.shape[1],img_arr.shape[2]))
        img_arr[lake_bool_4d] = np.float(j)

    img_arr = img_arr.transpose((2,1,0))
    img_arr = img_arr[0]
    # write_img(outPath, im_proj, im_geotrans, img_arr)
    return img_arr

def svm_predict2(clf, img_arr, array_num, outPath): 
    img_reshape = img_arr.reshape([img_arr.shape[0]*img_arr.shape[1],img_arr.shape[2]])
    best_model = clf.best_estimator_
    predict = best_model.predict(img_reshape)
    for j in range(array_num):
        lake_bool = predict == np.float(j)
        lake_bool = lake_bool[:,np.newaxis]
        # lake_bool_4col = np.concatenate((lake_bool,lake_bool,lake_bool,lake_bool),axis=1)
        lake_bool_4col = np.concatenate((lake_bool,lake_bool,lake_bool),axis=1)
        lake_bool_4d = lake_bool_4col.reshape((img_arr.shape[0],img_arr.shape[1],img_arr.shape[2]))
        img_arr[lake_bool_4d] = np.float(j)

    img_arr = img_arr.transpose((2,1,0))
    img_arr = img_arr[0]
    # write_img(outPath, im_proj, im_geotrans, img_arr)
    return img_arr


def read_img(filename):
    dataset=gdal.Open(filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0,0,im_width,im_height)

    del dataset 
    return im_proj,im_geotrans,im_width, im_height,im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset


def write_img_(filename, im_proj, im_geotrans, im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, gdal.GDT_Byte)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset



def array_change(inlist, outlist):
    for i in range(len(inlist[0])):
        outlist.append([j[i] for j in inlist])
    return outlist

def array_change2(inlist, outlist):
    for ele in inlist:
        for ele2 in ele:
            outlist.append(ele2)
    return outlist

def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    out = np.zeros_like(bands).astype(np.float32)
    # a = 0
    # b = 65535
    a = img_min
    b = img_max
    # print(a, b)
    c = np.percentile(bands[:, :], lower_percent)
    d = np.percentile(bands[:, :], higher_percent)
    # x = d-c
    # if (x==0).any():
    #     t = 0
    # else:
    t = a + (bands[:, :] - c) * (b - a) / (d - c)
    t[t < a] = a
    t[t > b] = b
    out[:, :] = t
    return out

def getTifSize(tif):
    dataSet = gdal.Open(tif)
    width = dataSet.RasterXSize
    height = dataSet.RasterYSize
    bands = dataSet.RasterCount
    geoTrans = dataSet.GetGeoTransform()
    proj = dataSet.GetProjection()
    return width,height,bands,geoTrans,proj


def partDivisionForBoundary(model,array_num,tif1,divisionSize,tempPath):
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    partWidth = partHeight = divisionSize

    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum =  width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight

    realName = os.path.split(tif1)[1].split(".")[0]

    tif1 = gdal.Open(tif1)
    for i in range(heightNum):
        for j in range(widthNum):

            startX = partWidth * j
            startY = partHeight * i

            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width  and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY

            outName = realName + '_i' +str(i)+ '_j' +str(j)+".tif"
            outPath = os.path.join(tempPath,outName)

            if not os.path.exists(outPath):

                driver = gdal.GetDriverByName("GTiff")
                outTif = driver.Create(outPath,realPartWidth,realPartHeight,1,gdal.GDT_Float32)
                outTif.SetGeoTransform(geoTrans)
                outTif.SetProjection(proj)

                data1 = tif1.ReadAsArray(startX,startY,realPartWidth,realPartHeight)
                data1 = data1.transpose((2,1,0))
                svmData = svm_predict(model,data1,array_num,outPath)
                outTif.GetRasterBand(1).WriteArray(svmData)
    return 1

def partStretch(tif1,divisionSize,outStratchPath,tempPath):

    width,height,bands,geoTrans,proj = getTifSize(tif1)
    # bands = 1
    partWidth = partHeight = divisionSize

    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum =  width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight

    realName = os.path.split(tif1)[1].split(".")[0]

    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath,width,height,1,gdal.GDT_Byte)
    if outTif!= None:
        outTif.SetGeoTransform(geoTrans)
        outTif.SetProjection(proj)
    for i in range(heightNum):
        for j in range(widthNum):

            startX = partWidth * j
            startY = partHeight * i

            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width  and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY

            partTifName = realName + '_i' +str(i)+ '_j' +str(j)+".tif"
            partTifPath = os.path.join(tempPath,partTifName)
            divisionImg = gdal.Open(partTifPath)
            for k in range(1):
                data1 = divisionImg.GetRasterBand(k+1).ReadAsArray(0,0,realPartWidth,realPartHeight)
                outPartBand = outTif.GetRasterBand(k+1)
                outPartBand.WriteArray(data1,startX,startY)


if __name__ == '__main__':
    img_p = 'xx/20181102.tif'
    shp_path = 'xx/point1/'
    temp_path = 'xx/temp/'
    model_path = 'xx/model/20181102model.pickle'
    re_path = "xx/20181102_c.tif"

    time1 = time.time()
    class_list = []
    for shp in os.listdir(shp_path):
        if shp[-4:] == '.shp':
            shp_full_path = os.path.join(shp_path, shp)
            class_type  = getPixels(shp_full_path, img_p)
            class_list.append(class_type)

    num = svm_train2(class_list, img_p, model_path)
    svm = get_model(model_path)
    
    partDivisionForBoundary(svm,num,img_p,5000,temp_path)
    partStretch(img_p,5000,re_path,temp_path)

    time2 = time.time()
    print((time2-time1)/3600)

你可能感兴趣的:(gdal,python,SVM,机器学习,支持向量机,分类)