python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化

根据我前述博客中对传统图像分割算法及图像块合并方法的实验探究,在此将这些方法用于遥感影像并尝试矢量化。
这个过程中我自己遇到了一个棘手的问题,在最后的结果那里有描述,希望知道的朋友帮忙解答一下,谢谢!
直接上代码:

# -*- coding: utf-8 -*-
import os
import cv2
import gdal
from osgeo import ogr,osr
import numpy as np
from skimage import morphology, color, measure
from skimage.segmentation import felzenszwalb, slic, quickshift
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.future import graph

def read_img(filename):
    """Read a whole raster file into memory with GDAL.

    Parameters
    ----------
    filename : str
        Path of the raster file to open.

    Returns
    -------
    tuple
        ``(im_width, im_height, im_proj, im_geotrans, im_data)``: the raster
        size in pixels, the value of ``GetProjection()``, the value of
        ``GetGeoTransform()`` and the array returned by ``ReadAsArray`` for
        the full raster extent.

    Raises
    ------
    IOError
        If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Without this check a failed open surfaces later as an opaque
        # AttributeError on None.
        raise IOError("Cannot open raster file: %s" % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the dataset reference so GDAL can close the file.
    del dataset
    return im_width, im_height, im_proj, im_geotrans, im_data

def write_img(filename,im_proj,im_geotrans,im_data):
    """Write a 2-D or 3-D array to a GeoTIFF file.

    Parameters
    ----------
    filename : str
        Path of the GeoTIFF to create.
    im_proj : str
        Projection passed to ``SetProjection``.
    im_geotrans : sequence
        Geotransform passed to ``SetGeoTransform``.
    im_data : numpy.ndarray
        Either a single band shaped ``(height, width)`` or a stack shaped
        ``(bands, height, width)``.

    Raises
    ------
    IOError
        If the GTiff driver cannot create ``filename``.
    """
    # Match on the exact dtype name. The previous substring test
    # ('int16' in name) sent signed int16 data to GDT_UInt16, and 32-bit
    # integers fell through to the lossy Float32 default.
    gdal_types = {
        'int8': gdal.GDT_Byte,    # kept from the original mapping
        'uint8': gdal.GDT_Byte,
        'int16': gdal.GDT_Int16,
        'uint16': gdal.GDT_UInt16,
        'int32': gdal.GDT_Int32,
        'uint32': gdal.GDT_UInt32,
    }
    # Every other dtype (floats, 64-bit integers, bool, ...) keeps the
    # original Float32 fallback.
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("Cannot create raster file: %s" % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # GDAL band indices are 1-based.
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Drop the dataset reference so the file is flushed and closed.
    del dataset

def DoesDriverHandleExtension(drv, ext):
    """Tell whether the driver's extension metadata mentions ``ext``.

    The check is a case-insensitive substring search of ``ext`` inside the
    driver's ``DMD_EXTENSIONS`` metadata item; a driver without that item
    never matches.
    """
    declared = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    if declared is None:
        return False
    return ext.lower() in declared.lower()


def GetExtension(filename):
    """Return the extension of ``filename`` without its leading dot.

    An empty string is returned when the name has no extension.
    """
    _, suffix = os.path.splitext(filename)
    return suffix[1:] if suffix.startswith('.') else suffix


def GetOutputDriversFor(filename):
    """List short names of vector drivers that could write ``filename``.

    A driver is considered when it advertises ``DCAP_CREATE`` or
    ``DCAP_CREATECOPY`` together with ``DCAP_VECTOR``. It is selected when
    it handles the file extension, or otherwise when ``filename`` starts
    (case-insensitively) with the driver's connection prefix.
    """
    extension = GetExtension(filename)
    matches = []
    for index in range(gdal.GetDriverCount()):
        candidate = gdal.GetDriver(index)

        can_create = (
            candidate.GetMetadataItem(gdal.DCAP_CREATE) is not None
            or candidate.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None
        )
        if not can_create:
            continue
        if candidate.GetMetadataItem(gdal.DCAP_VECTOR) is None:
            continue

        if extension and DoesDriverHandleExtension(candidate, extension):
            matches.append(candidate.ShortName)
            continue

        # No extension match: fall back to the connection-string prefix.
        prefix = candidate.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
        if prefix is not None and filename.lower().startswith(prefix.lower()):
            matches.append(candidate.ShortName)

    return matches

def GetOutputDriverFor(filename):
    """Pick the name of one vector driver for ``filename``.

    Returns the first matching driver, printing a notice when several
    match. When nothing matches, a name without extension defaults to
    ``'ESRI Shapefile'`` and a name with an extension raises ``Exception``.
    """
    candidates = GetOutputDriversFor(filename)
    extension = GetExtension(filename)

    if candidates:
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s" % (extension if extension else '', candidates[0]))
        return candidates[0]

    if extension:
        raise Exception("Cannot guess driver for %s" % filename)
    return 'ESRI Shapefile'

def _weight_mean_color(graph, src, dst, n):
    """Callback to handle merging nodes by recomputing mean color.
    The method expects that the mean color of `dst` is already computed.
    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    n : int
        A neighbor of `src` or `dst` or both.

    Returns
    -------
    data : dict
        A dictionary with the `"weight"` attribute set as the absolute
        difference of the mean color between node `dst` and `n`.
    """
    diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
    diff = np.linalg.norm(diff)
    return {'weight': diff}

def merge_mean_color(graph, src, dst):
    """Pre-merge callback for a mean-color distance graph.

    Adds the color sum and pixel count of `src` into `dst`, then
    recomputes the mean color of `dst` from the updated totals.

    Parameters
    ----------
    graph : RAG
        The graph being merged.
    src, dst : int
        The vertices about to be merged; only `dst` is modified.
    """
    absorbed = graph.nodes[src]
    target = graph.nodes[dst]
    target['total color'] += absorbed['total color']
    target['pixel count'] += absorbed['pixel count']
    target['mean color'] = target['total color'] / target['pixel count']

if __name__ == '__main__':
    # Pipeline: quickshift over-segmentation -> region merging on a mean-color
    # RAG -> label raster -> polygon shapefile. Boundary rasters are written
    # before and after merging so the two stages can be compared.
    img_path = "E:/geo_test/test.tif"
    temp_path = "E:/geo_test/temp/"
    im_width, im_height, im_proj, im_geotrans, im_data = read_img(img_path)

    # Reverse the axis order so the first axis of the GDAL array becomes the
    # last one. This needs a 3-D array, i.e. presumably a multi-band image.
    temp = im_data.transpose((2, 1, 0))
    segments_quick = quickshift(temp, kernel_size=3, max_dist=6, ratio=0.5)

    # ---- coarse segmentation: boundaries drawn over the input image ----
    mark0 = mark_boundaries(temp, segments_quick)
    save_path = temp_path + "qs_seg0.tif"
    re0 = mark0.transpose((2, 1, 0))  # back to the axis order write_img expects
    write_img(save_path, im_proj, im_geotrans, re0)

    # First channel of the boundary image, cast to uint8, used as the grid.
    grid_path = temp_path + "qs_grid0.tif"
    grid0 = np.uint8(re0[0, ...])
    write_img(grid_path, im_proj, im_geotrans, grid0)

    # Thin the grid to its skeleton, then binarize the result to 0/1.
    skeleton = morphology.skeletonize(grid0)
    border0 = np.multiply(grid0, skeleton)
    ret, border0 = cv2.threshold(border0, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border0.tif"
    write_img(border_path, im_proj, im_geotrans, border0)

    # ---- merge similar neighboring regions on a mean-color RAG ----
    g = graph.rag_mean_color(temp, segments_quick)
    labels2 = graph.merge_hierarchical(segments_quick, g, thresh=5,
                                       rag_copy=False,
                                       in_place_merge=True,
                                       merge_func=merge_mean_color,
                                       weight_func=_weight_mean_color)
    label_rgb2 = color.label2rgb(labels2, temp, kind='avg')
    rgb_path = temp_path + "qs_label.tif"  # despite the name, holds the label raster
    lb = labels2.transpose((1, 0))
    # lb = median(lb, disk(3))
    write_img(rgb_path, im_proj, im_geotrans, lb)

    # ---- merged segmentation: same boundary / grid / skeleton outputs ----
    mark = mark_boundaries(label_rgb2, labels2)
    save_path = temp_path + "qs_seg.tif"
    re = mark.transpose((2, 1, 0))
    write_img(save_path, im_proj, im_geotrans, re)

    grid_path = temp_path + "qs_grid.tif"
    grid = np.uint8(re[0, ...])
    write_img(grid_path, im_proj, im_geotrans, grid)

    skeleton = morphology.skeletonize(grid)
    border = np.multiply(grid, skeleton)
    ret, border = cv2.threshold(border, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border.tif"
    write_img(border_path, im_proj, im_geotrans, border)

    # out_shp = temp_path + "temp.shp"
    # RasterToLineshp(border_path, out_shp, 1)

    # ---- vectorize the merged label raster into polygons ----
    label_ds = gdal.Open(rgb_path)
    if label_ds is None:
        raise IOError("Cannot open raster file: %s" % rgb_path)
    label_band = label_ds.GetRasterBand(1)
    label_mask = label_band.GetMaskBand()

    dst_filename = temp_path + 'temp.shp'
    frmt = GetOutputDriverFor(dst_filename)
    drv = ogr.GetDriverByName(frmt)
    dst_ds = drv.CreateDataSource(dst_filename)

    dst_layername = 'out'
    srs = osr.SpatialReference(wkt=label_ds.GetProjection())
    dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbPolygon, srs=srs)
    # dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbLineString, srs=srs)

    # Integer attribute field; Polygonize is given its index (0) below.
    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0

    options = [""]
    options.append('DATASET_FOR_GEOREF=' + rgb_path)
    prog_func = gdal.TermProgress_nocb
    gdal.Polygonize(label_band, label_mask, dst_layer, dst_field, options,
                    callback=prog_func)

    # Release the GDAL/OGR handles that were actually opened above, children
    # before their owning dataset. The previous cleanup assigned None to
    # names that were never bound to any handle.
    label_mask = None
    label_band = None
    dst_layer = None
    dst_ds = None
    label_ds = None

# enum WKBGeometryType {
# wkbPoint = 1,
# wkbLineString = 2,
# wkbPolygon = 3,
# wkbTriangle = 17
# wkbMultiPoint = 4,
# wkbMultiLineString = 5,
# wkbMultiPolygon = 6,
# wkbGeometryCollection = 7,
# wkbPolyhedralSurface = 15,
# wkbTIN = 16
# wkbPointZ = 1001,
# wkbLineStringZ = 1002,
# wkbPolygonZ = 1003,
# wkbTriangleZ = 1017,
# wkbMultiPointZ = 1004,
# wkbMultiLineStringZ = 1005,
# wkbMultiPolygonZ = 1006,
# wkbGeometryCollectionZ = 1007,
# wkbPolyhedralSurfaceZ = 1015,
# wkbTINZ = 1016
# wkbPointM = 2001,
# wkbLineStringM = 2002,
# wkbPolygonM = 2003,
# wkbTriangleM = 2017
# wkbMultiPointM = 2004,
# wkbMultiLineStringM = 2005,
# wkbMultiPolygonM = 2006,
# wkbGeometryCollectionM = 2007,
# wkbPolyhedralSurfaceM = 2015,
# wkbTINM = 2016
# wkbPointZM = 3001,
# wkbLineStringZM = 3002,
# wkbPolygonZM = 3003,
# wkbTriangleZM = 3017
# wkbMultiPointZM = 3004,
# wkbMultiLineStringZM = 3005,
# wkbMultiPolygonZM = 3006,
# wkbGeometryCollectionZM = 3007,
# wkbPolyhedralSurfaceZM = 3015,
# wkbTINZM = 3016,
# }

对应的结果图如下:
原图:
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第1张图片
粗分割结果(代码中的qs_seg0.tif)
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第2张图片
粗分割格网(代码中的qs_grid0.tif)
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第3张图片
粗分割格网骨架(代码中的qs_border0.tif),格网的结果不是单线的,这里取了中心线。
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第4张图片
合并后的分割结果(代码中的qs_seg.tif):
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第5张图片
合并后的格网结果(代码中的qs_grid.tif)
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第6张图片
合并后的格网骨架结果(代码中的qs_border.tif):
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第7张图片
下面是矢量化以后的最终结果,即代码中的qs_label.tif经过矢量化以后得到的结果。这里说明一下:之所以不用栅格线来直接转矢量线,是因为我在GDAL里面并没有找到直接转化的方法,用目前的方法强行转的话只能得到双线,完全不对。找了很久也没找到解决办法,只能折中一下,先得到面,后面再面转线。看到的朋友如果知道用什么办法可以直接把栅格线转为矢量线,烦请告知一下,要求脱离arcgis哈。
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第8张图片

TO DO:
1.矢量面转线
2.线简化
3.线平滑
做完更新,感兴趣的朋友可以关注一下。

后续:
目前矢量面转矢量线肯定是没问题的,但是有个大问题:矢量线的平滑对我来说还有一定难度,想不到具体高效的方式。唯一想到的方式是将图层里的每一个节点找到,在节点位置不变的情况下,取出节点之间的线条逐个平滑,再放回到图层中。这样做有点慢,并且感觉实现起来也比较复杂,所以再次折中:我直接进行面的平滑,平滑完了再转线,看看有没有可能对结果有帮助。
虽然不做线平滑了,下面还是先给出面转线的代码:

# -*- coding: utf-8 -*-
import os
import gdal
from osgeo import ogr,osr
import numpy as np

def Test_Poly2Line(input_poly,output_line):
    """Convert the polygons of a shapefile into a line shapefile.

    Member 0 of every polygon geometry in the first layer of ``input_poly``
    is collected into one geometry collection, which is written as a single
    feature to ``output_line``. An existing ``output_line`` is deleted first.

    Parameters
    ----------
    input_poly : str
        Path of the polygon shapefile to read.
    output_line : str
        Path of the line shapefile to create.

    Raises
    ------
    IOError
        If the input cannot be opened or the output cannot be created.
    """
    ogr.RegisterAll()

    driver = ogr.GetDriverByName('ESRI Shapefile')
    # The source is only read, so open it read-only (0) rather than in
    # update mode.
    source_ds = driver.Open(input_poly, 0)
    if source_ds is None:
        raise IOError("Cannot open polygon shapefile: %s" % input_poly)
    source_layer = source_ds.GetLayer(0)

    # polygon -> geometry collection
    geomcol = ogr.Geometry(ogr.wkbGeometryCollection)
    for feat in source_layer:
        geom = feat.GetGeometryRef()
        if geom is None:
            # Skip features without geometry.
            continue
        ring = geom.GetGeometryRef(0)
        if ring is None:
            # Skip geometries with no members (e.g. empty polygons).
            continue
        geomcol.AddGeometry(ring)

    # geometry collection -> shapefile
    if os.path.exists(output_line):
        driver.DeleteDataSource(output_line)
    outDataSource = driver.CreateDataSource(output_line)
    if outDataSource is None:
        raise IOError("Cannot create line shapefile: %s" % output_line)
    # Carry the source spatial reference over so the output keeps its
    # projection; it was previously created without one.
    outLayer = outDataSource.CreateLayer(output_line,
                                         geom_type=ogr.wkbMultiLineString,
                                         srs=source_layer.GetSpatialRef())
    featureDefn = outLayer.GetLayerDefn()
    outFeature = ogr.Feature(featureDefn)
    outFeature.SetGeometry(geomcol)
    outLayer.CreateFeature(outFeature)

    # Release handles explicitly, children before their owning datasource,
    # so the output is flushed to disk.
    outFeature = None
    outLayer = None
    outDataSource = None
    source_layer = None
    source_ds = None


if __name__ == "__main__":
    # Turn the polygon shapefile produced by the segmentation script into a
    # line shapefile stored next to it.
    src_polygons = "E:/geo_test/temp/temp.shp"
    dst_lines = "E:/geo_test/temp/temp2line.shp"
    Test_Poly2Line(input_poly=src_polygons, output_line=dst_lines)

结果如下,可以看到这个结果和面完全保持一致,毕竟是gdal源码哈哈。
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第9张图片

下面说一下在面未转为线的时候就平滑,在下面的位置加入了中值滤波
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第10张图片
这是栅格面平滑后转化为面矢量的结果
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第11张图片
这是和之前没有进行平滑的结果的叠加对比,变化是有的,但是这里有一个大问题,就是锯齿状太严重。
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化_第12张图片

你可能感兴趣的:(gdal,python)