# -*- coding: utf-8 -*-
import os
import numpy as np
import gdal
import re
import time
#import arcpy
# Script: burn a vector layer into a binary (0/255) GeoTIFF that lies on exactly
# the same pixel grid (size, geotransform, projection) as a reference image.
imagefile = "D:/0220/605611/1/mask.tif"  # reference image; its georeferencing defines the output grid
line_file = "D:/0220/605611/1/t.shp"  # vector data to rasterize
outraster_file = "D:/0220/605611/1/road_temp.tif"  # output raster
# fields1 = [f1.name for f1 in arcpy.ListFields(line_file)]
field_name = 'FID_1'  # attribute burned into the raster; expected to be 1 for every feature
# Optional arcpy preparation (requires ArcGIS): make sure field_name exists and
# is set to 1 for every feature.
# if 'FID_1' in fields1:
#     arcpy.CalculateField_management(line_file,'FID_1',1,"PYTHON_9.3")
# if 'FID_1' not in fields1:
#     arcpy.AddField_management(line_file,'FID_1',"SHORT",'','',10)
#     arcpy.CalculateField_management(line_file,'FID_1',1,"PYTHON_9.3")
image = gdal.Open(imagefile)
if image is None:
    raise IOError("cannot open reference image: " + imagefile)
width = image.RasterXSize
height = image.RasterYSize
geotransform = image.GetGeoTransform()
ref = image.GetProjection()
# Extent of the reference image (assumes a north-up image without rotation terms).
xmin = geotransform[0]
xmax = xmin + geotransform[1] * width
ymax = geotransform[3]
ymin = ymax + geotransform[5] * height
# outputBounds pins the rasterization grid to the reference image. Without it,
# gdal.Rasterize fits the width x height grid to the *vector* extent, and
# stamping the reference geotransform onto that result later shifts it.
opt = {'format': "GTiff", 'width': width, 'height': height,
       'outputBounds': [xmin, ymin, xmax, ymax],
       'initValues': 0, 'attribute': field_name}
rasterized = gdal.Rasterize(outraster_file, line_file, **opt)
if rasterized is None:
    raise RuntimeError("gdal.Rasterize failed for " + line_file)
print("success")
# Threshold before casting so every burned (non-zero) pixel becomes 255 and the
# background stays 0, regardless of the attribute's numeric type.
data = np.where(rasterized.GetRasterBand(1).ReadAsArray() > 0, 255, 0).astype(np.uint8)
rasterized = None  # close the file before it is recreated below
driver = gdal.GetDriverByName("GTiff")
ds = driver.Create(outraster_file, width, height, 1, gdal.GDT_Byte)
ds.SetGeoTransform(geotransform)
ds.SetProjection(ref)
ds.GetRasterBand(1).WriteArray(data)
ds = None  # flush and close the output
print(outraster_file + ' success')
注意:上面注释掉的部分需要安装 ArcGIS 并配置好 arcpy 才能使用。如果没有 ArcGIS、想用纯 GDAL 实现,只需要在矢量文件(上面的 line_file)里创建一个字段(上面用的是 field_name,即 FID_1),并把所有要素的该字段都赋值为 1 即可;也可以自己搜索一下如何用 GDAL 创建字段。
补充纯gdal实现上面的过程:
# -*- coding: utf-8 -*-
import os
import numpy as np
import gdal
from osgeo import ogr
import re
import time
# Script (pure GDAL/OGR, no arcpy): add an integer field set to 1 on every
# feature, then burn it into a binary (0/255) GeoTIFF on the reference image's grid.
imagefile = "D:/0220/605611/1/mask.tif"  # reference image: defines the output grid and projection
line_file = "D:/0220/605611/1/t.shp"  # vector data to rasterize (modified in place: a field is added)
outraster_file = "D:/0220/605611/1/road_temp.tif"  # output raster
field_name = "FieldID"  # integer field created below and burned by gdal.Rasterize

# ---- Step 1: make sure every feature has field_name == 1 ----
ds = ogr.Open(line_file, 1)  # second argument 1 = open for update
if ds is None:
    raise IOError("cannot open vector file for update: " + line_file)
oLayer = ds.GetLayerByIndex(0)
print(oLayer.GetGeomType())  # OGR geometry type code: 1 point, 2 line string, 3 polygon
# Create the field only when it is missing; creating a duplicate field would
# otherwise make a re-run of this script fail.
if oLayer.GetLayerDefn().GetFieldIndex(field_name) < 0:
    oLayer.CreateField(ogr.FieldDefn(field_name, ogr.OFTInteger), 1)
fieldIndex0 = oLayer.GetLayerDefn().GetFieldIndex(field_name)
# Iterate the layer itself instead of GetFeature(i) for i in range(count):
# feature IDs are not guaranteed to be 0..count-1 in every format.
oLayer.ResetReading()
for feature in oLayer:
    feature.SetField(fieldIndex0, 1)  # every feature gets the value 1
    oLayer.SetFeature(feature)
ds = None  # close the data source so the edits are written to disk

# ---- Step 2: rasterize onto the reference image's pixel grid ----
image = gdal.Open(imagefile)
if image is None:
    raise IOError("cannot open reference image: " + imagefile)
width = image.RasterXSize
height = image.RasterYSize
geotransform = image.GetGeoTransform()
ref = image.GetProjection()
# Extent of the reference image (assumes a north-up image without rotation terms).
xmin = geotransform[0]
xmax = xmin + geotransform[1] * width
ymax = geotransform[3]
ymin = ymax + geotransform[5] * height
# outputBounds pins the grid to the reference image; without it the grid is
# fitted to the vector extent and the result ends up shifted.
opt = {'format': "GTiff", 'width': width, 'height': height,
       'outputBounds': [xmin, ymin, xmax, ymax],
       'initValues': 0, 'attribute': field_name}
rasterized = gdal.Rasterize(outraster_file, line_file, **opt)
if rasterized is None:
    raise RuntimeError("gdal.Rasterize failed for " + line_file)
print("success")
# Threshold before casting: any burned (non-zero) pixel -> 255, background -> 0.
data = np.where(rasterized.GetRasterBand(1).ReadAsArray() > 0, 255, 0).astype(np.uint8)
rasterized = None  # close the file before it is recreated below
driver = gdal.GetDriverByName("GTiff")
ds = driver.Create(outraster_file, width, height, 1, gdal.GDT_Byte)
ds.SetGeoTransform(geotransform)
ds.SetProjection(ref)
ds.GetRasterBand(1).WriteArray(data)
ds = None  # flush and close the output
print(outraster_file + ' success')
上面的写法如果只给 gdal.Rasterize 传了 width/height 而没有传 outputBounds,就会按矢量数据的范围来划分网格,生成的栅格与参考影像有偏移;同事测试了下面这种写法没有偏移,记录一下。
import gdal
from gdal import gdalconst
from osgeo import ogr
import numpy as np
import os
import glob
def vector2raster(inputfilePath, outputfile, templatefile, nodata_value=None):
    """Rasterize a vector layer onto the pixel grid of a template raster.

    Creates a 3-band Byte GeoTIFF with the same size, geotransform and
    projection as *templatefile*. Pixels covered by a feature are burned with
    255 in all three bands; all other pixels keep the initial value 0.

    Args:
        inputfilePath: path of the vector data source; its first layer is used.
        outputfile: path of the GeoTIFF to create.
        templatefile: raster whose grid (size, geotransform, projection) is copied.
        nodata_value: optional NoData value applied to every output band. It
            must fit in a Byte band (0-255). By default no NoData value is set;
            the previous hard-coded -999 cannot be stored in a Byte band.

    Raises:
        IOError: if an input cannot be opened or the output cannot be created.
        ValueError: if nodata_value is outside 0..255.
        RuntimeError: if gdal.RasterizeLayer reports an error.
    """
    template = gdal.Open(templatefile, gdalconst.GA_ReadOnly)
    if template is None:
        raise IOError("cannot open template raster: " + templatefile)
    if nodata_value is not None and not 0 <= nodata_value <= 255:
        raise ValueError("nodata_value must be in 0..255 for a Byte raster")

    vector = ogr.Open(inputfilePath)  # keep a reference: the layer depends on it
    if vector is None:
        raise IOError("cannot open vector file: " + inputfilePath)
    layer = vector.GetLayer()

    bands = [1, 2, 3]
    target = gdal.GetDriverByName('GTiff').Create(
        outputfile, template.RasterXSize, template.RasterYSize, len(bands), gdal.GDT_Byte)
    if target is None:
        raise IOError("cannot create output raster: " + outputfile)
    # Copy the template's georeferencing so output pixels line up with it exactly.
    target.SetGeoTransform(template.GetGeoTransform())
    target.SetProjection(template.GetProjection())
    if nodata_value is not None:
        for index in bands:
            target.GetRasterBand(index).SetNoDataValue(nodata_value)

    # Burn 255 into every band where a feature falls (white mask on black).
    err = gdal.RasterizeLayer(target, bands, layer, burn_values=[255] * len(bands))
    target.FlushCache()
    target = None  # close so the GeoTIFF is completely written to disk
    if err != 0:
        raise RuntimeError("gdal.RasterizeLayer failed for " + inputfilePath)
def raster2vector(inputfile='vector.tif', outfile='raster.shp'):
    """Polygonize band 1 of a raster into an ESRI Shapefile.

    Every connected region of equal pixel value becomes a polygon whose value
    is stored in the integer field ``DN``. Pixels excluded by the band's mask
    band are skipped by gdal.Polygonize. The raster's projection, if any, is
    written to the shapefile.

    Args:
        inputfile: raster to polygonize (band 1 is used).
        outfile: path of the shapefile to create; an existing one is replaced.

    Raises:
        IOError: if the raster cannot be opened or the shapefile cannot be created.
        RuntimeError: if gdal.Polygonize reports an error.
    """
    from osgeo import osr  # local import: only this function needs it

    ds = gdal.Open(inputfile, gdal.GA_ReadOnly)
    if ds is None:
        raise IOError("cannot open raster: " + inputfile)
    srcband = ds.GetRasterBand(1)
    maskband = srcband.GetMaskBand()

    drv = ogr.GetDriverByName('ESRI Shapefile')
    # CreateDataSource can fail when the target already exists, so remove it first.
    if os.path.exists(outfile):
        drv.DeleteDataSource(outfile)
    dst_ds = drv.CreateDataSource(outfile)
    if dst_ds is None:
        raise IOError("cannot create shapefile: " + outfile)

    # Carry the raster's coordinate system over to the vector output.
    srs = None
    wkt = ds.GetProjection()
    if wkt:
        srs = osr.SpatialReference()
        srs.ImportFromWkt(wkt)
    layer_name = os.path.splitext(os.path.basename(outfile))[0]
    dst_layer = dst_ds.CreateLayer(layer_name, srs=srs, geom_type=ogr.wkbPolygon)
    dst_layer.CreateField(ogr.FieldDefn('DN', ogr.OFTInteger))

    # Field index 0 is the DN field just created; it receives the pixel value.
    err = gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
    dst_ds = None  # close to flush the shapefile to disk
    if err != 0:
        raise RuntimeError("gdal.Polygonize failed for " + inputfile)
def mask_output_path(img):
    """Return the mask path for *img*: '<stem>_mask<ext>' in the same directory.

    Only the final extension is split off, so dots elsewhere in the name are
    kept, and a name without an extension still gets a distinct '_mask' path
    instead of mapping back to the input file.
    """
    base_dir, name = os.path.split(img)
    stem, ext = os.path.splitext(name)
    return os.path.join(base_dir, stem + '_mask' + ext)


if __name__ == "__main__":
    imgList = ['old_area.tif']
    shapeList = ['water.shp']
    # Rasterize each shapefile onto the grid of its paired template image.
    for img, shp in zip(imgList, shapeList):
        vector2raster(shp, mask_output_path(img), img)
    # raster2vector()