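下面的代码是用 Python GDAL 根据矢量文件（shapefile）在栅格图上裁剪出对应区域的实现。首先说明：裁剪出来的大小肯定是完全一致的；但我仔细看了下，结果貌似有一点偏移，这点偏移应该是无法避免的。（补充：经检查，这点偏移主要来源于把输出影像的原点直接设为矢量范围的 minX/maxY，而实际读取的像素是按整像素网格对齐的；若把原点设为所读窗口左上角像素的角点，偏移即可消除。）
代码有两版：第一版是把整幅图完全读进内存，问题是图像太大时就没办法处理；后来想到 GDAL 本来就可以不读出整幅图、只获取部分图像，于是改了一下，最后得到第二版，完美。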
# -*- coding: utf-8 -*-
import os
import numpy as np
import gdal
from osgeo import gdal, gdalnumeric, ogr, osr, gdal_array
gdal.UseExceptions()
def world2Pixel(geoMatrix, x, y):
    """
    Convert a georeferenced coordinate to (pixel, line) raster indices.

    Uses a GDAL geotransform (as returned by ``Dataset.GetGeoTransform()``),
    i.e. ``(originX, pixelWidth, rotX, originY, rotY, pixelHeight)``.

    Args:
        geoMatrix: six-element GDAL geotransform.
        x: X coordinate in the raster's spatial reference.
        y: Y coordinate in the raster's spatial reference.

    Returns:
        ``(pixel, line)`` as ints, truncated toward zero (so a coordinate just
        outside the left/top edge maps to 0, not -1).

    Note:
        The rotation terms ``geoMatrix[2]`` and ``geoMatrix[4]`` are ignored,
        so the result is only valid for north-up (non-rotated) rasters.
    """
    ulX = geoMatrix[0]
    ulY = geoMatrix[3]
    xDist = geoMatrix[1]
    yDist = geoMatrix[5]  # normally negative for north-up rasters
    pixel = int((x - ulX) / xDist)
    # Divide by the row height, not the column width: using xDist here was
    # only correct when pixels happen to be square.
    line = int((y - ulY) / yDist)
    return (pixel, line)
#
# EDIT: this is basically an overloaded version of gdal_array.OpenArray that
# takes xoff/yoff explicitly so they can be forwarded to CopyDatasetInfo.
#
def OpenArray(array, prototype_ds=None, xoff=0, yoff=0):
    """
    Wrap an array as a GDAL dataset and optionally copy georeferencing
    information from a prototype dataset.

    Args:
        array: array passed to ``gdal_array.OpenArray``.
        prototype_ds: an open GDAL dataset, a path to one, or None.
        xoff, yoff: pixel offset of ``array`` inside ``prototype_ds``;
            forwarded to ``gdalnumeric.CopyDatasetInfo``.

    Returns:
        Whatever ``gdal_array.OpenArray`` returned (this code guards against
        it being None and then skips the info copy).
    """
    ds = gdal_array.OpenArray(array)
    if ds is not None and prototype_ds is not None:
        # Accept both byte-string and unicode paths (Python 2 ``unicode`` /
        # Python 3 ``str``); the old class-name comparison rejected unicode.
        if isinstance(prototype_ds, (str, type(u''))):
            prototype_ds = gdal.Open(prototype_ds)
        if prototype_ds is not None:
            gdalnumeric.CopyDatasetInfo(prototype_ds, ds, xoff=xoff, yoff=yoff)
    return ds
def write_img(filename, im_proj, im_geotrans, im_data):
    """
    Write a 2-D (rows, cols) or 3-D (bands, rows, cols) array to a GeoTIFF.

    Args:
        filename: output path.
        im_proj: projection string passed to ``SetProjection`` (callers here
            pass the value of ``Dataset.GetProjection()``).
        im_geotrans: six-element GDAL geotransform for the output.
        im_data: array to write; its dtype selects the GDAL band type.

    Raises:
        ValueError: if ``im_data`` is neither 2-D nor 3-D.
    """
    # Map the NumPy dtype onto the matching GDAL type. The previous substring
    # test ('int8' in name / 'int16' in name) stored signed int16 as UInt16
    # and squeezed int32/uint32/float64 into Float32, corrupting values.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # unchanged from before; GDT_Int8 needs GDAL >= 3.7
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data
    elif im_data.ndim == 2:
        im_bands = 1
        im_height, im_width = im_data.shape
        bands = [im_data]
    else:
        raise ValueError("expected a 2-D or 3-D array, got shape %r" % (im_data.shape,))

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # Each band is written as a 2-D slice, which also covers a 3-D stack
    # holding a single band (WriteArray expects 2-D input).
    for band_index, band_data in enumerate(bands, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band_data)
    # Drop the last reference so GDAL flushes and closes the file.
    del dataset
def shpClipRasterSingle(shapefile_path, raster_path, save_path):
    """
    Crop ``raster_path`` to the bounding box of a shapefile layer and save the
    result as a GeoTIFF.

    Only the pixel window covered by the layer extent is read from disk
    (previously the whole raster was loaded first, which fails for very large
    images). The crop is rectangular: pixels outside the polygons but inside
    the extent are kept, not masked.

    Args:
        shapefile_path: path to the .shp file; the layer is looked up by the
            file's base name.
        raster_path: path to the source raster.
        save_path: output GeoTIFF path.

    Raises:
        IOError: if the shapefile or its layer cannot be opened.
        ValueError: if the layer extent does not overlap the raster.
    """
    srcImage = gdal.Open(raster_path)
    geoTrans = srcImage.GetGeoTransform()
    geoProj = srcImage.GetProjection()

    shapef = ogr.Open(shapefile_path)
    if shapef is None:
        raise IOError("Cannot open shapefile: %s" % shapefile_path)
    lyr = shapef.GetLayer(os.path.split(os.path.splitext(shapefile_path)[0])[1])
    if lyr is None:
        raise IOError("Cannot find layer in shapefile: %s" % shapefile_path)

    # Convert the layer extent to an image pixel window.
    minX, maxX, minY, maxY = lyr.GetExtent()
    ulX, ulY = world2Pixel(geoTrans, minX, maxY)
    lrX, lrY = world2Pixel(geoTrans, maxX, minY)
    # Clamp to the raster so an extent poking outside the image does not make
    # ReadAsArray fail with an out-of-range window.
    ulX, ulY = max(ulX, 0), max(ulY, 0)
    lrX = min(lrX, srcImage.RasterXSize)
    lrY = min(lrY, srcImage.RasterYSize)
    pxWidth = lrX - ulX
    pxHeight = lrY - ulY
    if pxWidth <= 0 or pxHeight <= 0:
        raise ValueError("Shapefile extent does not overlap raster: %s" % raster_path)

    clip = srcImage.ReadAsArray(ulX, ulY, pxWidth, pxHeight)
    print("Xoffset, Yoffset = ( %f, %f )" % (ulX, ulY))

    # Anchor the output at the top-left corner of the first pixel actually
    # read. Using (minX, maxY) directly shifted the image by up to one pixel,
    # because the read window is snapped to the pixel grid.
    newGeoTrans = list(geoTrans)
    newGeoTrans[0] = geoTrans[0] + ulX * geoTrans[1] + ulY * geoTrans[2]
    newGeoTrans[3] = geoTrans[3] + ulX * geoTrans[4] + ulY * geoTrans[5]
    write_img(save_path, geoProj, newGeoTrans, clip)
    gdal.ErrorReset()
def shpClipRaster(shapefile_path, srcArray, srcImage, save_path):
    """
    Crop an already-loaded raster to the bounding box of a shapefile layer
    and save the result as a GeoTIFF.

    Batch variant: the caller loads the raster once and passes both the pixel
    array and the open dataset, so repeated clips do not re-read the image.
    The crop is rectangular; pixels inside the extent are not masked by the
    polygons.

    Args:
        shapefile_path: path to the .shp file; the layer is looked up by the
            file's base name.
        srcArray: pixel data of ``srcImage``, either (rows, cols) or
            (bands, rows, cols).
        srcImage: open GDAL dataset ``srcArray`` was read from; supplies the
            geotransform and projection.
        save_path: output GeoTIFF path.

    Raises:
        IOError: if the shapefile or its layer cannot be opened.
        ValueError: if the layer extent does not overlap the raster.
    """
    geoTrans = srcImage.GetGeoTransform()
    geoProj = srcImage.GetProjection()

    shapef = ogr.Open(shapefile_path)
    if shapef is None:
        raise IOError("Cannot open shapefile: %s" % shapefile_path)
    lyr = shapef.GetLayer(os.path.split(os.path.splitext(shapefile_path)[0])[1])
    if lyr is None:
        raise IOError("Cannot find layer in shapefile: %s" % shapefile_path)

    # Convert the layer extent to an image pixel window.
    minX, maxX, minY, maxY = lyr.GetExtent()
    ulX, ulY = world2Pixel(geoTrans, minX, maxY)
    lrX, lrY = world2Pixel(geoTrans, maxX, minY)
    # Clamp to the array so negative indices do not silently wrap around.
    ulX, ulY = max(ulX, 0), max(ulY, 0)
    lrX = min(lrX, srcArray.shape[-1])
    lrY = min(lrY, srcArray.shape[-2])
    if lrX <= ulX or lrY <= ulY:
        raise ValueError("Shapefile extent does not overlap raster: %s" % shapefile_path)

    # '...' keeps this working for single-band (2-D) arrays as well.
    clip = srcArray[..., ulY:lrY, ulX:lrX]
    print("Xoffset, Yoffset = ( %f, %f )" % (ulX, ulY))

    # Anchor the output at the top-left corner of the first pixel kept, not
    # at (minX, maxY), which shifted the result by up to one pixel.
    newGeoTrans = list(geoTrans)
    newGeoTrans[0] = geoTrans[0] + ulX * geoTrans[1] + ulY * geoTrans[2]
    newGeoTrans[3] = geoTrans[3] + ulX * geoTrans[4] + ulY * geoTrans[5]
    write_img(save_path, geoProj, newGeoTrans, clip)
    gdal.ErrorReset()
if __name__ == "__main__":
    # Batch-clipping driver (the author's own use case). shpClipRasterSingle
    # re-reads the raster on every call, which is wasteful in a batch, so here
    # the raster is loaded once and handed to shpClipRaster.
    #
    # clip_config.txt: the first whitespace-separated token of each non-blank
    # line is used, in order:
    #   1. the source raster to clip
    #   2. a directory whose entries are sub-directories holding .tif/.shp files
    #   3. the output directory
    config_file = 'clip_config.txt'
    with open(config_file) as cfg:
        dirs = [line.split()[0] for line in cfg if line.strip()]
    img, image_path, save_path = [d.replace('\\', '/') for d in dirs[:3]]

    srcArray = gdalnumeric.LoadFile(img)
    srcImage = gdal.Open(img)

    for f in os.listdir(image_path):
        f_path = os.path.join(image_path, f)
        if not os.path.isdir(f_path):
            continue  # skip stray files next to the sub-directories
        file_list = os.listdir(f_path)
        shp_files = [e for e in file_list if e.endswith('.shp')]
        for ele in file_list:
            if not ele.endswith('.tif'):
                continue
            # The output is named after the .tif in the same sub-directory.
            name = os.path.splitext(ele)[0]
            save_img = os.path.join(save_path, name + '.tif')
            # NOTE(review): every .shp in the folder writes to the same
            # save_img, so with several .shp files only the last one survives.
            for ele2 in shp_files:
                # listdir() returns bare names; join with f_path so the
                # shapefile is found regardless of the working directory.
                shpClipRaster(os.path.join(f_path, ele2), srcArray, srcImage, save_img)
改进版（第二版，只读取需要的区域）：
# -*- coding: utf-8 -*-
import os
import numpy as np
import gdal
from osgeo import gdal, gdalnumeric, ogr, osr, gdal_array
gdal.UseExceptions()
def world2Pixel(geoMatrix, x, y):
    """
    Convert a georeferenced coordinate to (pixel, line) raster indices.

    Uses a GDAL geotransform (as returned by ``Dataset.GetGeoTransform()``),
    i.e. ``(originX, pixelWidth, rotX, originY, rotY, pixelHeight)``.

    Args:
        geoMatrix: six-element GDAL geotransform.
        x: X coordinate in the raster's spatial reference.
        y: Y coordinate in the raster's spatial reference.

    Returns:
        ``(pixel, line)`` as ints, truncated toward zero (so a coordinate just
        outside the left/top edge maps to 0, not -1).

    Note:
        The rotation terms ``geoMatrix[2]`` and ``geoMatrix[4]`` are ignored,
        so the result is only valid for north-up (non-rotated) rasters.
    """
    ulX = geoMatrix[0]
    ulY = geoMatrix[3]
    xDist = geoMatrix[1]
    yDist = geoMatrix[5]  # normally negative for north-up rasters
    pixel = int((x - ulX) / xDist)
    # Divide by the row height, not the column width: using xDist here was
    # only correct when pixels happen to be square.
    line = int((y - ulY) / yDist)
    return (pixel, line)
#
# EDIT: this is basically an overloaded version of gdal_array.OpenArray that
# takes xoff/yoff explicitly so they can be forwarded to CopyDatasetInfo.
#
def OpenArray(array, prototype_ds=None, xoff=0, yoff=0):
    """
    Wrap an array as a GDAL dataset and optionally copy georeferencing
    information from a prototype dataset.

    Args:
        array: array passed to ``gdal_array.OpenArray``.
        prototype_ds: an open GDAL dataset, a path to one, or None.
        xoff, yoff: pixel offset of ``array`` inside ``prototype_ds``;
            forwarded to ``gdalnumeric.CopyDatasetInfo``.

    Returns:
        Whatever ``gdal_array.OpenArray`` returned (this code guards against
        it being None and then skips the info copy).
    """
    ds = gdal_array.OpenArray(array)
    if ds is not None and prototype_ds is not None:
        # Accept both byte-string and unicode paths (Python 2 ``unicode`` /
        # Python 3 ``str``); the old class-name comparison rejected unicode.
        if isinstance(prototype_ds, (str, type(u''))):
            prototype_ds = gdal.Open(prototype_ds)
        if prototype_ds is not None:
            gdalnumeric.CopyDatasetInfo(prototype_ds, ds, xoff=xoff, yoff=yoff)
    return ds
def write_img(filename, im_proj, im_geotrans, im_data):
    """
    Write a 2-D (rows, cols) or 3-D (bands, rows, cols) array to a GeoTIFF.

    Args:
        filename: output path.
        im_proj: projection string passed to ``SetProjection`` (callers here
            pass the value of ``Dataset.GetProjection()``).
        im_geotrans: six-element GDAL geotransform for the output.
        im_data: array to write; its dtype selects the GDAL band type.

    Raises:
        ValueError: if ``im_data`` is neither 2-D nor 3-D.
    """
    # Map the NumPy dtype onto the matching GDAL type. The previous substring
    # test ('int8' in name / 'int16' in name) stored signed int16 as UInt16
    # and squeezed int32/uint32/float64 into Float32, corrupting values.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # unchanged from before; GDT_Int8 needs GDAL >= 3.7
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data
    elif im_data.ndim == 2:
        im_bands = 1
        im_height, im_width = im_data.shape
        bands = [im_data]
    else:
        raise ValueError("expected a 2-D or 3-D array, got shape %r" % (im_data.shape,))

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # Each band is written as a 2-D slice, which also covers a 3-D stack
    # holding a single band (WriteArray expects 2-D input).
    for band_index, band_data in enumerate(bands, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band_data)
    # Drop the last reference so GDAL flushes and closes the file.
    del dataset
def shpClipRaster(shapefile_path, raster_path, save_path):
    """
    Crop ``raster_path`` to the bounding box of a shapefile layer and save the
    result as a GeoTIFF, reading only the needed pixel window from disk.

    The crop is rectangular: pixels outside the polygons but inside the
    extent are kept, not masked.

    Args:
        shapefile_path: path to the .shp file; the layer is looked up by the
            file's base name.
        raster_path: path to the source raster.
        save_path: output GeoTIFF path.

    Raises:
        IOError: if the shapefile or its layer cannot be opened.
        ValueError: if the layer extent does not overlap the raster.
    """
    srcImage = gdal.Open(raster_path)
    geoTrans = srcImage.GetGeoTransform()
    geoProj = srcImage.GetProjection()

    shapef = ogr.Open(shapefile_path)
    if shapef is None:
        raise IOError("Cannot open shapefile: %s" % shapefile_path)
    lyr = shapef.GetLayer(os.path.split(os.path.splitext(shapefile_path)[0])[1])
    if lyr is None:
        raise IOError("Cannot find layer in shapefile: %s" % shapefile_path)

    # Convert the layer extent to an image pixel window.
    minX, maxX, minY, maxY = lyr.GetExtent()
    ulX, ulY = world2Pixel(geoTrans, minX, maxY)
    lrX, lrY = world2Pixel(geoTrans, maxX, minY)
    # Clamp to the raster so an extent poking outside the image does not make
    # ReadAsArray fail with an out-of-range window.
    ulX, ulY = max(ulX, 0), max(ulY, 0)
    lrX = min(lrX, srcImage.RasterXSize)
    lrY = min(lrY, srcImage.RasterYSize)
    pxWidth = lrX - ulX
    pxHeight = lrY - ulY
    if pxWidth <= 0 or pxHeight <= 0:
        raise ValueError("Shapefile extent does not overlap raster: %s" % raster_path)

    # Read only the window we need instead of the whole raster.
    clip = srcImage.ReadAsArray(ulX, ulY, pxWidth, pxHeight)
    print("Xoffset, Yoffset = ( %f, %f )" % (ulX, ulY))

    # Anchor the output at the top-left corner of the first pixel actually
    # read. Using (minX, maxY) directly shifted the image by up to one pixel,
    # because the read window is snapped to the pixel grid.
    newGeoTrans = list(geoTrans)
    newGeoTrans[0] = geoTrans[0] + ulX * geoTrans[1] + ulY * geoTrans[2]
    newGeoTrans[3] = geoTrans[3] + ulX * geoTrans[4] + ulY * geoTrans[5]
    write_img(save_path, geoProj, newGeoTrans, clip)
    gdal.ErrorReset()
if __name__ == "__main__":
    import sys

    # Usage: python <script> <shapefile> <raster> <output.tif>
    # Without exactly three arguments the example paths below are used.
    args = sys.argv[1:]
    if len(args) == 3:
        shp, img, out = args
    else:
        shp = "D:/wj_sample/e/1/1.shp"
        img = "D:/tool/test_seg/suzhou_wujiangqu1.tif"
        out = "D:/tool/test_seg/temp.tif"
    shpClipRaster(shp, img, out)