环境:python2或3都可以,有osgeo(gdal)(本机版本2.2.4)、scikit-learn(本机版本0.20.3)、numpy(本机版本1.15.0)就行
这份代码针对的是4波段遥感影像;如果用于其他波段数的影像,稍微修改一下即可,需要修改的位置后面会在代码注释中标出。
数据位置:链接:https://pan.baidu.com/s/14i-ePeWm-gnIPSsrgHmnMw
提取码:qkgz
首先在图像上选取样本点,也就是提取这些点位置上的图像像素值作为训练样本(我是用 ArcGIS 选的点,一个矢量文件代表一个类别,后面会提供测试数据)。然后把代码里对应的路径改成自己的即可。以下是代码部分:
# -*- coding: utf-8 -*-
from osgeo import ogr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import numpy as np
from sklearn.svm import SVC
def getPixels(shp, img, window=10):
    """Collect training pixels around every point of a shapefile.

    For each point feature in ``shp`` a ``window`` x ``window`` patch is read
    from every band of ``img``.  The patch starts ``window // 2`` pixels
    up/left of the pixel that contains the point.  Only the origin and pixel
    size of the geotransform are used; its rotation terms are ignored.

    Args:
        shp: Path to an ESRI Shapefile whose features are points (the code
            reads each geometry through ``GetX()`` / ``GetY()``).
        img: Path to the raster to sample.
        window: Side length of the square patch in pixels.  The default of
            10 reproduces the original hard-coded 10 x 10 patch.

    Returns:
        A 2-D numpy array of shape ``(n_samples, n_bands)``.  Rows are ordered
        patch-pixel first, point second (all points for patch pixel 0, then
        all points for patch pixel 1, ...).  An empty ``(0, n_bands)`` array
        is returned when no point yields a usable patch.

    Raises:
        ValueError: If ``window`` is smaller than 1.
        SystemExit: If the shapefile or the image cannot be opened (kept from
            the original, which calls ``sys.exit(1)``).
    """
    if window < 1:
        raise ValueError('window must be >= 1')

    driver = ogr.GetDriverByName('ESRI Shapefile')
    vector_ds = driver.Open(shp, 0)
    if vector_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)
    layer = vector_ds.GetLayer()

    # Gather the map coordinates of every point feature.
    points = []
    feature = layer.GetNextFeature()
    while feature is not None:
        geometry = feature.GetGeometryRef()
        if geometry is not None:  # features without geometry cannot be sampled
            points.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()
    raster_ds = gdal.Open(img, gdal.GA_ReadOnly)
    if raster_ds is None:
        print('Could not open image')
        sys.exit(1)
    rows = raster_ds.RasterYSize
    cols = raster_ds.RasterXSize
    bands = raster_ds.RasterCount
    transform = raster_ds.GetGeoTransform()
    x_origin, pixel_width = transform[0], transform[1]
    y_origin, pixel_height = transform[3], transform[5]
    half = window // 2

    # Fetch the band objects once instead of once per point.
    band_objs = [raster_ds.GetRasterBand(b + 1) for b in range(bands)]

    patches = []
    for x, y in points:
        # floor() rather than int(): int() truncates toward zero and would
        # map points just left of / above the origin onto pixel 0.
        x_start = int(np.floor((x - x_origin) / pixel_width)) - half
        y_start = int(np.floor((y - y_origin) / pixel_height)) - half
        if (x_start < 0 or y_start < 0
                or x_start + window > cols or y_start + window > rows):
            # Reading outside the raster fails, so skip such points.
            print('Skipping point (%s, %s): sample window is outside the image'
                  % (x, y))
            continue
        patches.append([
            band.ReadAsArray(x_start, y_start, window, window).flatten()
            for band in band_objs
        ])

    if not patches:
        return np.empty((0, bands))

    # (n_points, n_bands, window*window) -> (window*window, n_points, n_bands)
    # -> (n_samples, n_bands); this keeps the original row ordering.
    stacked = np.asarray(patches)
    return stacked.transpose((2, 0, 1)).reshape(-1, bands)
def svmDeal(classArray, img_arr, outPath, im_proj, im_geotrans):
    """Train an RBF SVM on the class samples, classify the image and save it.

    Args:
        classArray: Sequence of 2-D arrays, one per class, each shaped
            ``(n_samples, n_bands)``.  The position of an array in the
            sequence is used as its class label (0, 1, 2, ...).
        img_arr: 3-D image array whose last axis holds the bands; the script
            passes GDAL's array transposed with ``(2, 1, 0)``.  Any number
            of bands is accepted.
        outPath: Path of the single-band raster to write.
        im_proj: Projection passed through to ``write_img``.
        im_geotrans: Geotransform passed through to ``write_img``.

    Raises:
        ValueError: If ``classArray`` is empty.
    """
    # Keep the classes as a list: np.asarray() on classes with different
    # sample counts would build a ragged array, an error on current NumPy.
    samples = [np.asarray(cls) for cls in classArray]
    if not samples:
        raise ValueError('classArray must contain at least one class')

    train_x = np.concatenate(samples, axis=0)
    train_y = np.concatenate(
        [idx * np.ones(cls.shape[0]) for idx, cls in enumerate(samples)])

    dim0, dim1, n_bands = img_arr.shape
    pixels = img_arr.reshape((dim0 * dim1, n_bands))

    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_x, train_y)
    predicted = svc.predict(pixels)

    # One label per pixel: fold the flat prediction back onto the two spatial
    # axes and swap them, undoing the caller's (2, 1, 0) transpose.  This
    # matches the original "paint every band, keep band 0" result without
    # assuming 4 bands and without overwriting the caller's image array.
    class_map = predicted.reshape((dim0, dim1)).T.astype(img_arr.dtype)
    write_img(outPath, im_proj, im_geotrans, np.ascontiguousarray(class_map))
def read_img(filename):
    """Read a whole raster and its georeferencing with GDAL.

    Args:
        filename: Path of the raster to open.

    Returns:
        A tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``:
        the projection, the geotransform, the raster width and height in
        pixels, and the full pixel array as returned by
        ``Dataset.ReadAsArray``.

    Raises:
        IOError: If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open reports failure by returning None.
        raise IOError('Could not open image: ' + str(filename))
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    dataset = None  # release the GDAL dataset handle
    return im_proj, im_geotrans, im_width, im_height, im_data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF.

    Args:
        filename: Path of the GeoTIFF to create.
        im_proj: Projection to store in the file.
        im_geotrans: Geotransform to store in the file.
        im_data: 2-D array ``(height, width)`` or 3-D array
            ``(bands, height, width)``.

    Raises:
        IOError: If the GTiff driver cannot create ``filename`` (for example
            because the target directory does not exist).
    """
    # Map the numpy dtype to a GDAL type.  The original substring test wrote
    # signed int16 data as UInt16; dtypes not listed here fall back to
    # Float32, as before.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    # Treat a 2-D array as a single band so both cases share one write loop.
    layers = im_data if len(im_data.shape) == 3 else im_data[np.newaxis]
    im_bands, im_height, im_width = layers.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError('Could not create output image: ' + str(filename))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for index in range(im_bands):
        dataset.GetRasterBand(index + 1).WriteArray(layers[index])
    dataset.FlushCache()
    dataset = None  # release the handle so the file is closed
def array_change(inlist, outlist):
    """Append the transpose of ``inlist`` to ``outlist`` and return it.

    Args:
        inlist: Sequence of equally long, indexable rows.
        outlist: List that receives one new list per column of ``inlist``;
            it is modified in place.

    Returns:
        ``outlist`` itself.  The column count is taken from the first row;
        an empty ``inlist`` leaves ``outlist`` unchanged.
    """
    if len(inlist) == 0:
        # Nothing to transpose; the original failed here on inlist[0].
        return outlist
    width = len(inlist[0])
    outlist.extend([row[col] for row in inlist] for col in range(width))
    return outlist
def array_change2(inlist, outlist):
    """Flatten ``inlist`` by one level into ``outlist`` and return it.

    Args:
        inlist: Iterable of iterables.
        outlist: List that receives every inner element in order; it is
            modified in place.

    Returns:
        ``outlist`` itself.
    """
    for group in inlist:
        outlist.extend(group)
    return outlist
if __name__ == '__main__':
    # Script entry point: sample the training pixels of every class, classify
    # the image with an SVM and write the class map.
    img_p = 'E:/1/1/data2/cgnr/0/cgnr_0.tif'  # source image
    # Directory of point shapefiles named 0.shp (first class), 1.shp (second
    # class), 2.shp (third class), ...; the class values in the result follow
    # this order.  Without vector files, the samples can be supplied directly
    # as the list passed to svmDeal.
    shp_path = 'E:/1/1/data2/cgnr/0/point2/'
    out_path = 'E:/abg_test/1/data2/cgnr/0/cgnr_0_sd.tif'  # classification result

    def _class_order(name):
        # Numeric file names sort by value (2.shp before 10.shp); any other
        # name sorts after them, alphabetically.
        stem = name[:-4]
        if stem.isdigit():
            return (0, int(stem), stem)
        return (1, 0, stem)

    # os.listdir() returns entries in arbitrary order, so sort explicitly to
    # keep class label N tied to N.shp.
    shp_names = sorted(
        (name for name in os.listdir(shp_path)
         if name.lower().endswith('.shp')),
        key=_class_order)
    class_list = [getPixels(os.path.join(shp_path, name), img_p)
                  for name in shp_names]

    time1 = time.time()
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_p)
    if im_data.ndim == 2:
        # Give a single-band image a band axis so the transpose works.
        im_data = im_data[np.newaxis, :, :]
    im_data = im_data.transpose((2, 1, 0))
    svmDeal(class_list, im_data, out_path, im_proj, im_geotrans)
    time2 = time.time()
    print((time2 - time1) / 3600)  # elapsed time in hours
上面的代码有人反馈运行不通;如果不行,可以试试下面这个版本(逻辑相同,只是路径不同,并在最后打印了运行耗时):
# -*- coding: utf-8 -*-
"""SVM classification of a multi-band raster from point-shapefile samples.

Every ``*.shp`` file in the sample directory holds the training points of one
class.  Pixels around each point are read from the image, an RBF SVM is
trained on them, and the whole image is classified into a single-band GeoTIFF.
"""
from osgeo import ogr
from osgeo import gdal
from osgeo.gdalconst import *  # noqa: F401,F403 - keeps the GA_*/GDT_* names in scope
import os
import sys
import time

import numpy as np
from sklearn.svm import SVC


def getPixels(shp, img, window=10):
    """Collect training pixels around every point of a shapefile.

    For each point feature in ``shp`` a ``window`` x ``window`` patch is read
    from every band of ``img``.  The patch starts ``window // 2`` pixels
    up/left of the pixel that contains the point.  Only the origin and pixel
    size of the geotransform are used; its rotation terms are ignored.

    Args:
        shp: Path to an ESRI Shapefile whose features are points.
        img: Path to the raster to sample.
        window: Side length of the square patch in pixels (default 10).

    Returns:
        A 2-D numpy array of shape ``(n_samples, n_bands)``; an empty
        ``(0, n_bands)`` array when no point yields a usable patch.

    Raises:
        ValueError: If ``window`` is smaller than 1.
        SystemExit: If the shapefile or the image cannot be opened.
    """
    if window < 1:
        raise ValueError('window must be >= 1')

    driver = ogr.GetDriverByName('ESRI Shapefile')
    vector_ds = driver.Open(shp, 0)
    if vector_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)
    layer = vector_ds.GetLayer()

    # Gather the map coordinates of every point feature.
    points = []
    feature = layer.GetNextFeature()
    while feature is not None:
        geometry = feature.GetGeometryRef()
        if geometry is not None:
            points.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()
    raster_ds = gdal.Open(img, gdal.GA_ReadOnly)
    if raster_ds is None:
        print('Could not open image')
        sys.exit(1)
    rows = raster_ds.RasterYSize
    cols = raster_ds.RasterXSize
    bands = raster_ds.RasterCount
    transform = raster_ds.GetGeoTransform()
    x_origin, pixel_width = transform[0], transform[1]
    y_origin, pixel_height = transform[3], transform[5]
    half = window // 2

    band_objs = [raster_ds.GetRasterBand(b + 1) for b in range(bands)]

    patches = []
    for x, y in points:
        # floor() so that points left of / above the origin get a negative
        # offset instead of being truncated onto pixel 0.
        x_start = int(np.floor((x - x_origin) / pixel_width)) - half
        y_start = int(np.floor((y - y_origin) / pixel_height)) - half
        if (x_start < 0 or y_start < 0
                or x_start + window > cols or y_start + window > rows):
            print('Skipping point (%s, %s): sample window is outside the image'
                  % (x, y))
            continue
        patches.append([
            band.ReadAsArray(x_start, y_start, window, window).flatten()
            for band in band_objs
        ])

    if not patches:
        return np.empty((0, bands))

    # (n_points, n_bands, window*window) -> (n_samples, n_bands)
    stacked = np.asarray(patches)
    return stacked.transpose((2, 0, 1)).reshape(-1, bands)


def svmDeal(classArray, img_arr, outPath, im_proj, im_geotrans):
    """Train an RBF SVM on the class samples, classify the image and save it.

    Args:
        classArray: Sequence of 2-D arrays, one per class, each shaped
            ``(n_samples, n_bands)``; the position in the sequence is the
            class label.
        img_arr: 3-D image array whose last axis holds the bands (GDAL's
            array transposed with ``(2, 1, 0)``).
        outPath: Path of the single-band raster to write.
        im_proj: Projection passed through to ``write_img``.
        im_geotrans: Geotransform passed through to ``write_img``.

    Raises:
        ValueError: If ``classArray`` is empty.
    """
    samples = [np.asarray(cls) for cls in classArray]
    if not samples:
        raise ValueError('classArray must contain at least one class')

    train_x = np.concatenate(samples, axis=0)
    train_y = np.concatenate(
        [idx * np.ones(cls.shape[0]) for idx, cls in enumerate(samples)])

    dim0, dim1, n_bands = img_arr.shape
    pixels = img_arr.reshape((dim0 * dim1, n_bands))

    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_x, train_y)
    predicted = svc.predict(pixels)

    # Fold the flat prediction back onto the two spatial axes and swap them,
    # undoing the caller's (2, 1, 0) transpose; works for any band count.
    class_map = predicted.reshape((dim0, dim1)).T.astype(img_arr.dtype)
    write_img(outPath, im_proj, im_geotrans, np.ascontiguousarray(class_map))


def read_img(filename):
    """Read a whole raster and its georeferencing with GDAL.

    Returns:
        ``(im_proj, im_geotrans, im_width, im_height, im_data)``.

    Raises:
        IOError: If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        raise IOError('Could not open image: ' + str(filename))
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    dataset = None  # release the GDAL dataset handle
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a 2-D ``(height, width)`` or 3-D ``(bands, height, width)`` array
    to a GeoTIFF.

    Raises:
        IOError: If the GTiff driver cannot create ``filename``.
    """
    # Signed and unsigned types are mapped separately; dtypes not listed here
    # fall back to Float32.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    # Treat a 2-D array as a single band so both cases share one write loop.
    layers = im_data if len(im_data.shape) == 3 else im_data[np.newaxis]
    im_bands, im_height, im_width = layers.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError('Could not create output image: ' + str(filename))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for index in range(im_bands):
        dataset.GetRasterBand(index + 1).WriteArray(layers[index])
    dataset.FlushCache()
    dataset = None  # release the handle so the file is closed


def array_change(inlist, outlist):
    """Append the transpose of ``inlist`` to ``outlist`` and return it."""
    if len(inlist) == 0:
        return outlist
    width = len(inlist[0])
    outlist.extend([row[col] for row in inlist] for col in range(width))
    return outlist


def array_change2(inlist, outlist):
    """Flatten ``inlist`` by one level into ``outlist`` and return it."""
    for group in inlist:
        outlist.extend(group)
    return outlist


if __name__ == '__main__':
    img_p = 'C:/Users/DELL/Desktop/data/g1_test.tif'
    shp_path = 'C:/Users/DELL/Desktop/data/point/'
    out_path = 'C:/Users/DELL/Desktop/data/11.tif'

    def _class_order(name):
        # Numeric file names sort by value (2.shp before 10.shp); any other
        # name sorts after them, alphabetically.
        stem = name[:-4]
        if stem.isdigit():
            return (0, int(stem), stem)
        return (1, 0, stem)

    # os.listdir() returns entries in arbitrary order, so sort explicitly to
    # keep class label N tied to N.shp.
    shp_names = sorted(
        (name for name in os.listdir(shp_path)
         if name.lower().endswith('.shp')),
        key=_class_order)
    class_list = [getPixels(os.path.join(shp_path, name), img_p)
                  for name in shp_names]

    time1 = time.time()
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_p)
    if im_data.ndim == 2:
        # Give a single-band image a band axis so the transpose works.
        im_data = im_data[np.newaxis, :, :]
    im_data = im_data.transpose((2, 1, 0))
    svmDeal(class_list, im_data, out_path, im_proj, im_geotrans)
    time2 = time.time()
    print((time2 - time1) / 3600)  # elapsed time in hours