python gdal矢量转栅格

# -*- coding: utf-8 -*-
import os
import numpy as np
import gdal
import re
import time
#import arcpy

# Input/output locations used by the rasterization below.
imagefile="D:/0220/605611/1/mask.tif"  # reference raster; its size, geotransform and projection are reused
line_file = "D:/0220/605611/1/t.shp"  # vector data to convert
outraster_file = "D:/0220/605611/1/road_temp.tif"  # where the result is stored

# Name of the attribute field handed to gdal.Rasterize as the 'attribute'
# option. The disabled arcpy lines below are an ArcGIS-only way to prepare
# that field: they would set 'FID_1' to 1 when it exists, or add it as a
# SHORT field and then set it to 1 when it does not.
# fields1 = [f1.name for f1 in arcpy.ListFields(line_file)]
field_name='FID_1'
# field_name='FID_1'
# if 'FID_1' in fields1:
#     arcpy.CalculateField_management(line_file,'FID_1',1,"PYTHON_9.3")
# if 'FID_1' not in fields1:
#     arcpy.AddField_management(line_file,'FID_1',"SHORT",'','',10)
#     arcpy.CalculateField_management(line_file,'FID_1',1,"PYTHON_9.3")

# Rasterize `line_file` onto the pixel grid of the reference image and store
# the result in `outraster_file` as a single-band byte mask (0 = background,
# 255 = burned).
image = gdal.Open(imagefile)
if image is None:
    raise RuntimeError("cannot open reference image: " + imagefile)
width = image.RasterXSize
height = image.RasterYSize
geotransform = image.GetGeoTransform()
ref = image.GetProjection()

# Georeferenced extent of the reference image.
# NOTE(review): this assumes a north-up, unrotated reference raster
# (geotransform[5] < 0 and no rotation terms) -- confirm for your data.
xmin = geotransform[0]
xmax = xmin + geotransform[1] * width
ymax = geotransform[3]
ymin = ymax + geotransform[5] * height

# Without 'outputBounds' the rasterized extent is derived from the vector
# layer, not from the reference image, so stamping the reference geotransform
# on the result afterwards shifts every pixel. Pinning the bounds together
# with width/height makes the output grid coincide with the reference grid.
opt = {
    'format': "GTiff",
    'width': width,
    'height': height,
    'outputBounds': [xmin, ymin, xmax, ymax],
    'initValues': 0,
    'attribute': field_name,
}
rasterized = gdal.Rasterize(outraster_file, line_file, **opt)
if rasterized is None:
    raise RuntimeError("rasterization failed for: " + line_file)
print("success")
data = rasterized.GetRasterBand(1).ReadAsArray().astype(np.uint8)
# Release the dataset before the same path is re-created below.
rasterized = None

# Binarize: every burned pixel becomes 255.
data[data > 0] = 255

# Rewrite the file as a byte raster carrying the reference georeferencing.
driver = gdal.GetDriverByName("GTiff")
ds = driver.Create(outraster_file, width, height, 1, gdal.GDT_Byte)
if ds is None:
    raise RuntimeError("cannot create output raster: " + outraster_file)
ds.SetGeoTransform(geotransform)
ds.SetProjection(ref)
ds.GetRasterBand(1).WriteArray(data)
ds = None  # dropping the reference flushes and closes the file
print(outraster_file + '  success')

注意：注释掉的部分是在安装了 ArcGIS 并且配置了 arcpy 的情况下使用的。如果你没有 ArcGIS、想用纯 GDAL 实现，那么需要在矢量文件（上面的 line_file）里创建一个字段，并把该字段的值全部赋为 1；也可以自己搜索一下怎么用 GDAL 创建字段。

补充纯gdal实现上面的过程:

# -*- coding: utf-8 -*-
import os
import numpy as np
import gdal
from osgeo import ogr
import re
import time

imagefile = "D:/0220/605611/1/mask.tif"  # reference raster
line_file = "D:/0220/605611/1/t.shp"  # vector data to convert (edited in place below)
outraster_file = "D:/0220/605611/1/road_temp.tif"  # where the result is stored

# Attribute field whose value is burned into the raster.
field_name = "FieldID"

# Second argument 1 opens the data source in update mode so that the field
# can be added and the features rewritten.
ds = ogr.Open(line_file, 1)
if ds is None:
    raise RuntimeError("cannot open vector file for update: " + line_file)
oLayer = ds.GetLayerByIndex(0)
print(oLayer.GetGeomType())  # geometry type code; 1, 2, 3 = point, line, polygon
oDefn = oLayer.GetLayerDefn()

# Optional: walk the existing field definitions. Uncomment the print to list
# each field's name, type, width and precision.
iFieldCount = oDefn.GetFieldCount()
for iAttr in range(iFieldCount):
    oField = oDefn.GetFieldDefn(iAttr)
    # print(oField.GetNameRef(),
    # oField.GetFieldTypeName(oField.GetType()),
    # oField.GetWidth(),
    # oField.GetPrecision())

# Create the burn field only when it is missing, so that re-running the
# script does not try to add the same field a second time.
if oDefn.GetFieldIndex(field_name) < 0:
    oFieldID = ogr.FieldDefn(field_name, ogr.OFTInteger)
    if oLayer.CreateField(oFieldID, 1) != ogr.OGRERR_NONE:
        raise RuntimeError("cannot create field: " + field_name)
fieldIndex0 = oLayer.GetLayerDefn().GetFieldIndex(field_name)

# Set the field to 1 on every feature. Iterating the layer avoids assuming
# that feature ids are exactly 0..n-1 and is safe for an empty layer.
value = 1
oLayer.ResetReading()
for feature in oLayer:
    feature.SetField(fieldIndex0, value)
    oLayer.SetFeature(feature)

# Drop the references so the edits are flushed and the file is closed before
# it is read again.
feature = None
oLayer = None
ds = None

# Rasterize `line_file` onto the pixel grid of the reference image and store
# the result in `outraster_file` as a single-band byte mask (0 = background,
# 255 = burned).
image = gdal.Open(imagefile)
if image is None:
    raise RuntimeError("cannot open reference image: " + imagefile)
width = image.RasterXSize
height = image.RasterYSize
geotransform = image.GetGeoTransform()
ref = image.GetProjection()

# Georeferenced extent of the reference image.
# NOTE(review): this assumes a north-up, unrotated reference raster
# (geotransform[5] < 0 and no rotation terms) -- confirm for your data.
xmin = geotransform[0]
xmax = xmin + geotransform[1] * width
ymax = geotransform[3]
ymin = ymax + geotransform[5] * height

# Without 'outputBounds' the rasterized extent is derived from the vector
# layer, not from the reference image, so stamping the reference geotransform
# on the result afterwards shifts every pixel. Pinning the bounds together
# with width/height makes the output grid coincide with the reference grid.
opt = {
    'format': "GTiff",
    'width': width,
    'height': height,
    'outputBounds': [xmin, ymin, xmax, ymax],
    'initValues': 0,
    'attribute': field_name,
}
rasterized = gdal.Rasterize(outraster_file, line_file, **opt)
if rasterized is None:
    raise RuntimeError("rasterization failed for: " + line_file)
print("success")
data = rasterized.GetRasterBand(1).ReadAsArray().astype(np.uint8)
# Release the dataset before the same path is re-created below.
rasterized = None

# Binarize: every burned pixel becomes 255.
data[data > 0] = 255

# Rewrite the file as a byte raster carrying the reference georeferencing.
driver = gdal.GetDriverByName("GTiff")
ds = driver.Create(outraster_file, width, height, 1, gdal.GDT_Byte)
if ds is None:
    raise RuntimeError("cannot create output raster: " + outraster_file)
ds.SetGeoTransform(geotransform)
ds.SetProjection(ref)
ds.GetRasterBand(1).WriteArray(data)
ds = None  # dropping the reference flushes and closes the file
print(outraster_file + '  success')

上面有偏移,同事测试了下面这个没有,记录一下

import gdal
from gdal import gdalconst
from osgeo import ogr
import numpy as np
import os
import glob

def vector2raster(inputfilePath, outputfile, templatefile):
    """Rasterize a vector file onto the pixel grid of a template raster.

    A 3-band byte GeoTIFF with the template's size, geotransform and
    projection is created at *outputfile*, and the first layer of
    *inputfilePath* is burned into bands 1-3.

    Args:
        inputfilePath: Path of the vector data source to rasterize.
        outputfile: Path of the GeoTIFF to create.
        templatefile: Path of the raster whose grid and georeferencing are
            copied.

    Raises:
        RuntimeError: If an input cannot be opened, the output cannot be
            created, or rasterization reports an error.
    """
    data = gdal.Open(templatefile, gdalconst.GA_ReadOnly)
    if data is None:
        raise RuntimeError('cannot open template raster: %s' % templatefile)
    vector = ogr.Open(inputfilePath)
    if vector is None:
        raise RuntimeError('cannot open vector file: %s' % inputfilePath)
    layer = vector.GetLayer()

    x_res = data.RasterXSize
    y_res = data.RasterYSize
    targetDataset = gdal.GetDriverByName('GTiff').Create(
        outputfile, x_res, y_res, 3, gdal.GDT_Byte)
    if targetDataset is None:
        raise RuntimeError('cannot create output raster: %s' % outputfile)
    targetDataset.SetGeoTransform(data.GetGeoTransform())
    targetDataset.SetProjection(data.GetProjection())

    band = targetDataset.GetRasterBand(1)
    # NOTE(review): -999 lies outside the 0-255 range of a Byte band, so no
    # pixel can ever equal it; kept as in the original -- confirm the intent.
    NoData_value = -999
    band.SetNoDataValue(NoData_value)
    band.FlushCache()

    # No burn values are passed, so the binding's default is used
    # (presumably 255 per band -- confirm against the GDAL documentation).
    err = gdal.RasterizeLayer(targetDataset, [1, 2, 3], layer)

    # Flush and close explicitly so the file is complete on return.
    targetDataset.FlushCache()
    band = None
    targetDataset = None
    if err != gdal.CE_None:
        raise RuntimeError('rasterization failed for: %s' % inputfilePath)


def raster2vector(inputfile='vector.tif', outfile='raster.shp'):
    """Polygonize band 1 of a raster into an ESRI Shapefile.

    Connected regions of equal pixel value become polygons, and the pixel
    value is written to the integer field ``DN``. The band's mask band is
    used as the polygonization mask.

    Args:
        inputfile: Path of the raster to read (default ``'vector.tif'``).
        outfile: Path of the shapefile to create (default ``'raster.shp'``).

    Raises:
        RuntimeError: If the raster cannot be opened, the shapefile cannot be
            created, or polygonization reports an error.
    """
    # Imported locally because the file's top-level imports do not include osr.
    from osgeo import osr

    ds = gdal.Open(inputfile, gdal.GA_ReadOnly)
    if ds is None:
        raise RuntimeError('cannot open raster: %s' % inputfile)
    srcband = ds.GetRasterBand(1)
    maskband = srcband.GetMaskBand()

    drv = ogr.GetDriverByName('ESRI Shapefile')
    dst_ds = drv.CreateDataSource(outfile)
    if dst_ds is None:
        raise RuntimeError('cannot create shapefile: %s' % outfile)

    # Carry the raster's projection over so the polygons stay georeferenced;
    # fall back to no spatial reference when the raster has none.
    wkt = ds.GetProjection()
    srs = osr.SpatialReference(wkt=wkt) if wkt else None
    dst_layer = dst_ds.CreateLayer(outfile, srs=srs)

    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0  # index of the 'DN' field created just above
    options = []
    err = gdal.Polygonize(srcband, maskband, dst_layer, dst_field, options)

    # Drop the references so the shapefile is flushed and closed on return.
    dst_layer = None
    dst_ds = None
    if err != gdal.CE_None:
        raise RuntimeError('polygonization failed for: %s' % inputfile)


def build_mask_path(img):
    """Return the mask output path for *img*: ``<stem>_mask<ext>``.

    Only the final extension is split off, so dots elsewhere in the file or
    directory name are left untouched, and a name without an extension still
    yields a path different from the input.
    """
    stem, ext = os.path.splitext(img)
    return stem + '_mask' + ext


if __name__=="__main__":
    # Each template image is paired with the shapefile at the same position.
    imgList = ['old_area.tif']
    shapeList = ['water.shp']
    for img, shp in zip(imgList, shapeList):
        outputfile = build_mask_path(img)
        vector2raster(shp, outputfile, img)
    # raster2vector()

你可能感兴趣的:(gdal,python)