**
**
随机森林是一个包含多个决策树的分类器,因其运算速度快、分类精度高、算法稳定等特点,被广泛应用到遥感图像的分类研究中。Scikit-Learn作为Python 编程语言的免费软件机器学习库,提供了对随机森林算法的支持,但没有提供针对遥感影像分类的相关函数。因此,本篇文章将为读者介绍利用Python及其扩展包Scikit-Learn对遥感影像进行随机森林分类的完整过程,包括:ShapeFile格式样本数据的读取、栅格数据读取和裁剪、利用Scikit-Learn的RandomForestClassifier模块进行样本训练和遥感影像分类。
直接执行命令pip install scikit-learn,所有依赖库都会自动安装。安装完成后,添加代码from sklearn.ensemble import RandomForestClassifier即可使用
在ArcGIS中绘制训练样本,格式为shpfile,可以是点类型或面类型,建立Value字段,用于存储分类编号,例如1-湿地,2-湖泊,3-水稻。
通过shp样本获取对应的栅格值,需要使用多边形裁剪栅格数据,我们使用射线算法。全部代码如下(TrainByRandomForest.py):
代码如下(示例):
from sklearn.ensemble import RandomForestClassifier
import osgeo
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
def GetSubRaster(inraster,polygonPoints:list):
polygonPoints.append(polygonPoints[0])#面多边形坐标封闭
print("当前多边形节点数量:"+str(len(polygonPoints)))
#计算最小边界矩形
minX=10000000000000
maxX=-minX
minY=100000000000000000
maxY=-minY
for point in polygonPoints:
if point.X<minX:minX=point.X
if point.X>maxX:maxX=point.X
if point.Y<minY:minY=point.Y
if point.Y>maxY:maxY=point.Y
leftX=minX
upY=maxY
rightX=maxX
bottomY=minY
rds = gdal.Open(inraster)
transform = (rds.GetGeoTransform())
lX = transform[0]#左上角点
lY = transform[3]
rX = transform[1]#分辨率
rY = transform[5]
wpos=int((leftX-lX)/rX)
hpos=int((upY-lY)/rY)
width=int((rightX-leftX)/rX)
height=int((bottomY-upY)/rY)
BandsCount = rds.RasterCount
arr = rds.ReadAsArray(wpos,hpos,width,height)
fixX=list()
nodatavalue=rds.GetRasterBand(1).GetNoDataValue()
for i in range(height):
if height>200:
print(f"多边形裁剪进度:{round(((i+1)/height)*100,4)}%")
Y=upY+i*rY+.00001
#射线算法只需要比对多边形的一条水平线上的边
pointsindex=list()
for k in range(len(polygonPoints)-1):
point1=polygonPoints[k]
point2=polygonPoints[k+1]
if (point1.Y>=Y and point2.Y<=Y) or (point1.Y<=Y and point2.Y>=Y):
pointsindex.append(k)
for j in range(width):
count=0
for m in (pointsindex):
point1=polygonPoints[m]
point2=polygonPoints[m+1]
X=leftX+j*rX+.00001
if point1.X==point2.X:
intersectX=point1.X
if intersectX>X:count+=1
else:
k=(point2.Y-point1.Y)/(point2.X-point1.X)
if k==0:
if X<point1.X or X<point2.X:
count+=1
else:
intersectX=(Y-point1.Y)/k+point1.X
if intersectX>X:count+=1
if count%2==0:
if BandsCount>1:
for bc in range(BandsCount):
arr[bc][i][j]=(nodatavalue)
else:
arr[i][j]=-1
#为了测试结果的正确性,可以先将其写到硬盘
#WriteRaster("test.tif",arr,inraster,width,height,BandsCount,leftX,upY)
return arr,width,height,BandsCount,leftX,upY
代码如下(示例):
def createClassifier(inraster,inshp,field:str="Id",treenum:int=100):
rasterspatial = gdal.Open(inraster)
spatial2=osr.SpatialReference()
spatial2.ImportFromWkt(rasterspatial.GetProjectionRef())
shpspatial=ogr.Open(inshp)
layer=shpspatial.GetLayer(0)
spatial1=layer.GetSpatialRef()
ct=osr.CreateCoordinateTransformation(spatial1,spatial2)
oFeature = layer.GetNextFeature()
# 下面开始遍历图层中的要素
geom=oFeature.GetGeometryRef()
if geom.GetGeometryType()==ogr.wkbPoint:
return createClassifierByPoint(inraster,inshp)
k=geom.GetGeometryType()
if geom.GetGeometryType()!=ogr.wkbPolygon:
print("样本必须为单部件多边形")
return False
trainX = list()
trainY = list()
print("读取样本")
while oFeature is not None:
geom=oFeature.GetGeometryRef()
wkt=geom.ExportToWkt()
points=WKTToPoints(wkt)
polygonPoints=[]
value=oFeature.GetField(field)
for point in points:
pC=ct.TransformPoint(point.X,point.Y,0)
polygonPoints.append(Point(pC[0],pC[1]))
arr,width,height,BandsCount,leftX,upY=GetSubRaster(inraster,polygonPoints)
for i in range(height):
for k in range(width):
nodata=True
tem = list()
for bc in range(BandsCount):
v=int(arr[bc][i][k])
tem.append(v)
if v>0:nodata=False
if nodata:
continue
trainX.append(tem)
trainY.append(int(value))
oFeature = layer.GetNextFeature()
ct=None
spatial1=None
spatial2=None
print("训练样本")
clf = RandomForestClassifier(n_estimators=treenum)
clf.fit(trainX, trainY)#训练样本
print("训练完成")
return clf
from sklearn.ensemble import RandomForestClassifier
import osgeo
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
import numpy
import os
import sys
import TrainByRandomForest as tbrf
def RandomForestClassification(ClassifyRaster,TrainRaster,TrainShp,outfile,blockSize=0,treenum=100,max_depth=10):
rds = gdal.Open(ClassifyRaster)
#print((rds.GetRasterBand(1).DataType))
transform = (rds.GetGeoTransform())
lX = transform[0]#左上角点
lY = transform[3]
rX = transform[1]#分辨率
rY = transform[5]
width = rds.RasterXSize
height = rds.RasterYSize
bX = lX + rX * width#右下角点
bY = lY + rY * height
BandsCount = rds.RasterCount
clf = tbrf.createClassifier(TrainRaster,TrainShp)
Z = list()
fixX = list()
if blockSize == 0:
p,a = memory_usage()
pv = 0.6 / 10000
checkMemory(2000)#内存小于2GB,不在计算
bl = (a - 2000) / pv / height / BandsCount
blockSize = math.ceil(height / bl)
if blockSize < 1:blockSize = 1
if blockSize > 1:blockSize+=5
if blockSize != 1:
blockHeight = 0
modHeight = 0
modHeight = height % blockSize
if modHeight == 0:
blockHeight = int(height / blockSize)
else:
blockHeight = int(height / blockSize)
print(f"分块大小{width}*{blockHeight}")
for bs in range(blockSize):
print(f"计算块{bs+1}/{blockSize}")
checkMemory(500)
arr = rds.ReadAsArray(0,bs * blockHeight,width,blockHeight)
for i in range(blockHeight):
print(f"分块:{bs+1}/{blockSize}添加分类数据{round((i+1)*100/blockHeight,4)}%")
for k in range(width):
tem = list()
for bc in range(BandsCount):
tem.append(int(arr[bc][i][k]))
fixX.append(tem)
print(f"分块:{bs+1}/{blockSize}计算分类结果……")
checkMemory(800)
z = clf.predict(fixX)
Z.extend(z)
fixX = list()
arr = None
print(f"计算余数:{width}*{modHeight}")
checkMemory(500)
arr = rds.ReadAsArray(0,blockSize * blockHeight,width,modHeight)
if modHeight > 0:
for i in range(modHeight):
print(f"余块:添加分类数据{round((i+1)*100/modHeight,4)}%")
for k in range(width):
tem = list()
for bc in range(BandsCount):
tem.append(int(arr[bc][i][k]))
fixX.append(tem)
print("余块:计算分类结果……")
checkMemory(500)
z = clf.predict(fixX)
Z.extend(z)
Z = numpy.array(Z)
#Z=Z.reshape(1,width*height)
Z = Z.reshape(height,width)
fixX = None
arr = None
else:
checkMemory(1000)
arr = rds.ReadAsArray(0,0,width,height)
for i in range(height):
print(f"添加训练样本{round((i+1)*100/height,4)}%")
for k in range(width):
tem = list()
for bc in range(BandsCount):
tem.append(int(arr[bc][i][k]))
fixX.append(tem)
arr = None
print("计算分类结果……")
Z = clf.predict(fixX)
Z = numpy.array(Z)
Z = Z.reshape(height,width)
driver = gdal.GetDriverByName("GTiff")
filepath,filename = os.path.split(outfile)
short,ext = os.path.splitext(filename)
print("创建输出文件")
out = driver.Create(outfile,width,height,1,rds.GetRasterBand(1).DataType)
out.SetGeoTransform(transform)
out.SetProjection(rds.GetProjectionRef())
print("写入数据……")
out.GetRasterBand(1).WriteArray(Z)
out.FlushCache()
out = None
print("计算完成")
分类图像和分类结果
完整数据和代码:https://pan.baidu.com/s/19-uDLKLVVZVjop-yiYovWQ
提取码:szta
视频教程地址请查看最上方