用 GDAL 切割带火星偏移(GCJ-02)的瓦片

修改了gdal2tiles(非最新版),使之:

  • 支持火星偏移
  • 默认改为谷歌(XYZ)瓦片模式(/z/x/y);原版 gdal2tiles 为 TMS 模式,y 轴方向相反
  • 支持瓦片压缩
  • 需要安装 GDAL;如需压缩,还需安装 pngquant

开启压缩后耗时会显著增加,因此默认关闭;如需开启,请将脚本中的 COMPRESS 改为 True。压缩依赖 Python 的 pngquant 模块及 pngquant 程序(Windows 下需将 pngquant.exe 放在当前工作目录),各平台请自行安装。

使用方式:
python3 /tif2tile.py -p $(nproc) -i result.tif -s EPSG:4326 -o /myresult

  • Linux 下可通过 $(nproc) 自动使用全部 CPU 核心;如需火星偏移,请额外加上 -g 参数
# -*- coding=utf-8 -*-
import platform
from xml.etree import ElementTree
import json
from osgeo import gdal, osr
from uuid import uuid4
import sys
import shutil
import tempfile
import os
import math
from multiprocessing import Pool, Process, Manager, cpu_count
# Tile edge length in pixels; only 512 or 256 are supported.
TILESIZE = 256
# GDAL driver name and file extension used for the written tiles.
TILEDRIVER, TILEEXT = 'PNG', 'png'
# File size in bytes of an empty (fully transparent) tile, keyed by tile size
# as a string; tiles of exactly this size are deleted when compressing.
EMPTY = {'512': 1096, "256": 334}
# pngquant executable: on Windows it is looked up in the current working
# directory; on other platforms None is passed to pngquant.config.
QUANTFILE = (os.path.join(os.getcwd(), "pngquant.exe")
             if platform.system() == 'Windows' else None)
# PNG compression switch (slow, therefore disabled by default).
COMPRESS = False
if COMPRESS:
    # Third-party wrapper, only required when compression is enabled.
    import pngquant

class GlobalMercator(object):
    """Spherical Web Mercator (EPSG:3857) helpers for tile math.

    Pixel and tile indices follow the TMS layout used by gdal2tiles: pixel
    (0, 0) is the south-west corner of the projected world and y grows
    northward. Pixel/tile conversions use the module-level ``TILESIZE``.
    """

    def __init__(self):
        # Meters per pixel at zoom 0: the equator length spread over one tile.
        self.initialResolution = 2 * math.pi * 6378137 / TILESIZE
        # Half the equator length: projected x/y lie in [-originShift, originShift].
        self.originShift = 2 * math.pi * 6378137 / 2.0

    def MetersToLatLon(self, mx, my):
        """Convert EPSG:3857 meters to WGS84 ``(lat, lon)``, rounded to 6 decimals."""
        lon = (mx / self.originShift) * 180.0
        lat = (my / self.originShift) * 180.0
        lat = 180 / math.pi * \
            (2 * math.atan(math.exp(lat * math.pi / 180.0)) - math.pi / 2.0)
        # Drop digits beyond 6 decimals (about 0.1 m); they are of no use here.
        lat = round(lat, 6)
        lon = round(lon, 6)
        return lat, lon

    def PixelsToMeters(self, px, py, zoom):
        """Convert TMS pixel coordinates at ``zoom`` to EPSG:3857 meters."""
        res = self.Resolution(zoom)
        mx = px * res - self.originShift
        my = py * res - self.originShift
        return mx, my

    def MetersToPixels(self, mx, my, zoom):
        """Convert EPSG:3857 meters to TMS pixel coordinates at ``zoom``."""
        res = self.Resolution(zoom)
        px = (mx + self.originShift) / res
        py = (my + self.originShift) / res
        return px, py

    def PixelsToTile(self, px, py):
        """Return the TMS tile ``(tx, ty)`` covering pixel ``(px, py)``."""
        tx = int(math.ceil(px / float(TILESIZE)) - 1)
        ty = int(math.ceil(py / float(TILESIZE)) - 1)
        return tx, ty

    def MetersToTile(self, mx, my, zoom):
        """Return the TMS tile ``(tx, ty)`` containing a point given in meters."""
        px, py = self.MetersToPixels(mx, my, zoom)
        return self.PixelsToTile(px, py)

    def LatLonToMeters(self, lat, lon):
        """Convert WGS84 ``(lat, lon)`` to EPSG:3857 ``(mx, my)`` meters."""
        mx = lon * self.originShift / 180.0
        my = math.log(math.tan((90 + lat) * math.pi / 360.0)) / \
            (math.pi / 180.0)
        my = my * self.originShift / 180.0
        return mx, my

    def TileBounds(self, tx, ty, zoom):
        """Return ``(minx, miny, maxx, maxy)`` in meters of TMS tile ``(tx, ty)``.

        ``ty`` is a TMS row (y grows northward); no flipping is done here.
        """
        minx, miny = self.PixelsToMeters(tx * TILESIZE, ty * TILESIZE, zoom)
        maxx, maxy = self.PixelsToMeters(
            (tx + 1) * TILESIZE, (ty + 1) * TILESIZE, zoom)
        return (minx, miny, maxx, maxy)

    def Resolution(self, zoom):
        """Return the ground resolution (meters per pixel) at ``zoom``."""
        return self.initialResolution / (2 ** zoom)

    def WGS84ToGCJ02(self, lat, lng):
        """Convert WGS84 ``(lat, lng)`` to GCJ-02 ("Mars" coordinates).

        The result is rounded to 6 decimals. No out-of-China check is made:
        the offset is always applied.
        """
        pi = 3.1415926535897932384626
        # Eccentricity squared and semi-major axis of the Krasovsky ellipsoid
        # used by the GCJ-02 algorithm.
        ee = 0.00669342162296594323
        a = 6378245.0
        # Algorithm taken from https://github.com/wandergis/coordTransform_py

        def _transformlat(lng, lat):
            ret = -100.0 + 2.0 * lng + 3.0 * lat + 0.2 * lat * lat + \
                0.1 * lng * lat + 0.2 * math.sqrt(math.fabs(lng))
            ret += (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
                    math.sin(2.0 * lng * pi)) * 2.0 / 3.0
            ret += (20.0 * math.sin(lat * pi) + 40.0 *
                    math.sin(lat / 3.0 * pi)) * 2.0 / 3.0
            ret += (160.0 * math.sin(lat / 12.0 * pi) + 320 *
                    math.sin(lat * pi / 30.0)) * 2.0 / 3.0
            return ret

        def _transformlng(lng, lat):
            ret = 300.0 + lng + 2.0 * lat + 0.1 * lng * lng + \
                0.1 * lng * lat + 0.1 * math.sqrt(math.fabs(lng))
            ret += (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
                    math.sin(2.0 * lng * pi)) * 2.0 / 3.0
            ret += (20.0 * math.sin(lng * pi) + 40.0 *
                    math.sin(lng / 3.0 * pi)) * 2.0 / 3.0
            ret += (150.0 * math.sin(lng / 12.0 * pi) + 300.0 *
                    math.sin(lng / 30.0 * pi)) * 2.0 / 3.0
            return ret
        dlat = _transformlat(lng - 105.0, lat - 35.0)
        dlng = _transformlng(lng - 105.0, lat - 35.0)
        radlat = lat / 180.0 * pi
        magic = math.sin(radlat)
        magic = 1 - ee * magic * magic
        sqrtmagic = math.sqrt(magic)
        dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
        dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
        mglat = lat + dlat
        mglng = lng + dlng
        mglat = round(mglat, 6)
        mglng = round(mglng, 6)
        return mglat, mglng

    def ZoomForPixelSize(self, pixelSize):
        """Return the deepest zoom whose resolution is still >= ``pixelSize``.

        The result is clamped to 0 (never upscale above zoom 0). If
        ``pixelSize`` is finer than every supported level, 31 is returned.
        """
        for i in range(32):
            if pixelSize > self.Resolution(i):
                # Zoom i is already finer than the data: step back one level,
                # but never below 0 (the old ``i != -1`` test let -1 through).
                return max(0, i - 1)
        # Finer than zoom 31: use the deepest supported level instead of None.
        return 31


def check(status, message):
    """Abort with exit status 3 when ``status`` is truthy.

    ``message`` is written to stderr after the "运行出错" label; when
    ``status`` is falsy nothing happens and None is returned.
    """
    if not status:
        return
    sys.stderr.write("运行出错: %s\n" % message)
    sys.exit(3)

def GoogleTile(zoom, ty):
    """Convert a TMS row index ``ty`` to the XYZ (Google) row index at ``zoom``.

    For 512-pixel tiles the row count is taken from one zoom level lower,
    matching how tile ranges are computed for large tiles.
    """
    exponent = zoom - 1 if TILESIZE == 512 else zoom
    return 2 ** exponent - 1 - ty

def gettempfilename(suffix):
    """Return a temporary file path ending in ``suffix`` (the file is not created).

    When the ``_`` environment variable mentions ``wine``, a random name is
    built inside ``$TMP`` (or the current directory); otherwise
    ``tempfile.mktemp`` picks the name.
    """
    launcher = os.environ.get('_')
    if launcher is not None and 'wine' in launcher:
        import random
        import time
        tmp_dir = os.environ['TMP'] if 'TMP' in os.environ else '.'
        random.seed(time.time())
        name = 'file%d' % random.randint(0, 1000000000)
        return os.path.join(tmp_dir, name + suffix)
    return tempfile.mktemp(suffix)


def add_alpha_band_to_string_vrt(vrt_string):
    """Return ``vrt_string`` (a warped VRT) with an extra alpha band.

    A ``VRTWarpedRasterBand`` with ``ColorInterp=Alpha`` is inserted right
    after the existing ``VRTRasterBand`` elements. ``GDALWarpOptions`` gains a
    ``DstAlphaBand`` pointing at it and an ``INIT_DEST=0`` option.

    Raises:
        Exception: if an existing band is already an alpha band.
    """
    root = ElementTree.fromstring(vrt_string)

    # Count the bands and find the child position just after them; the scan
    # stops at the first non-band element that follows a band.
    insert_at = 0
    band_count = 0
    for child in list(root):
        if child.tag != "VRTRasterBand":
            if band_count:
                break
        else:
            band_count += 1
            interp = child.find("./ColorInterp")
            if interp is not None and interp.text == "Alpha":
                raise Exception("Alpha band already present")
        insert_at += 1

    alpha_index = str(band_count + 1)

    alpha_band = ElementTree.Element(
        "VRTRasterBand",
        {'dataType': "Byte", "band": alpha_index,
         "subClass": "VRTWarpedRasterBand"})
    interp_node = ElementTree.Element("ColorInterp", {})
    interp_node.text = "Alpha"
    alpha_band.append(interp_node)
    root.insert(insert_at, alpha_band)

    warp_options = root.find(".//GDALWarpOptions")
    dst_alpha = ElementTree.Element("DstAlphaBand", {})
    dst_alpha.text = alpha_index
    warp_options.append(dst_alpha)
    init_dest = ElementTree.Element("Option", {"name": "INIT_DEST"})
    init_dest.text = "0"
    warp_options.append(init_dest)

    return ElementTree.tostring(root).decode()


class TileDetail(object):
    """Read/write windows for one base tile, filled from ``geoQuery`` results.

    Only the attributes declared below are accepted as keyword arguments;
    unknown keys are ignored.
    """
    # Tile column / row / zoom.
    tx = ty = tz = 0
    # Pixel window read from the source dataset (offset and size).
    rx = ry = 0
    rxsize = rysize = 0
    # Where that window is written inside the query dataset (offset and size).
    wx = wy = 0
    wxsize = wysize = 0

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            if hasattr(self, name):
                setattr(self, name, value)


class TileJobInfo(object):
    """Shared settings of one tiling job (source VRT, output folder, zooms).

    Only attributes declared on the class are accepted as keyword arguments;
    unknown keys are silently ignored.
    """
    srcFile = ""
    nbDataBands = 0
    outputFilePath = ""
    # Per-zoom (tminx, tminy, tmaxx, tmaxy) tile ranges.
    tminmax = []
    tminz = 0
    tmaxz = 0
    outGeoTrans = []

    def __init__(self, **kwargs):
        # Give each instance its own list defaults: the class-level lists are
        # shared, so mutating them through one instance would leak into all.
        self.tminmax = []
        self.outGeoTrans = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)


class TileJobsMaker(object):
    """Prepare a tiling job: warp the input to EPSG:3857 and list base tiles.

    Usage: construct, call ``openData()``, then ``makeMetadata()`` and
    ``makeBaseTiles()``.
    """

    def __init__(self, inputFile, outputFolder, options):
        # Tiles are written with 4 bands: data bands 1..3 plus alpha as band 4.
        self.dataBandsCount = 4
        # Work through a GDAL VRT so that warping happens lazily while reading.
        self.vrtFilename = os.path.join(
            tempfile.mkdtemp(), str(uuid4()) + '.vrt')
        self.inputFile = inputFile
        self.outputFolder = outputFolder
        self.options = options
        # Parse "-z" ("min" or "min-max"); tminz == -1 means "compute later".
        self.tminz = -1
        if self.options.zoom:
            minmax = self.options.zoom.split('-', 1)
            minmax.extend([''])
            zoom_min, zoom_max = minmax[:2]
            self.tminz = int(zoom_min)
            if zoom_max:
                self.tmaxz = int(zoom_max)
            else:
                self.tmaxz = int(zoom_min)

    def updateNoDataValue(self):
        """Make nodata pixels of the warped dataset transparent.

        The warped VRT is round-tripped through a temporary file to add the
        ``INIT_DEST=NO_DATA`` and ``UNIFIED_SRC_NODATA=YES`` warp options. It
        is then reopened with ``NODATA_VALUES`` set to ``0 0 0 0``.
        """
        def gdalVrtWarp(options, key, value):
            """Prepend ``<Option name=key>value</Option>`` to the warp options."""
            tb = ElementTree.TreeBuilder()
            tb.start("Option", {"name": key})
            tb.data(value)
            tb.end("Option")
            elem = tb.close()
            options.insert(0, elem)

        tempFile = tempfile.mktemp('-TileJobsMaker.vrt')
        self.warpedDataset.GetDriver().CreateCopy(tempFile, self.warpedDataset)
        with open(tempFile, 'r', encoding='utf-8') as f:
            vrtString = f.read()
        vrtRoot = ElementTree.fromstring(vrtString)
        options = vrtRoot.find("GDALWarpOptions")
        # Initialise every destination pixel to nodata.
        gdalVrtWarp(options, "INIT_DEST", "NO_DATA")
        # A pixel is nodata only when all bands match nodata, instead of
        # testing each band independently.
        gdalVrtWarp(options, "UNIFIED_SRC_NODATA", "YES")
        vrtString = ElementTree.tostring(vrtRoot).decode()
        with open(tempFile, 'w', encoding='utf-8') as f:
            f.write(vrtString)
        # Reopen the patched VRT, then remove the temporary file.
        correctedDataset = gdal.Open(tempFile)
        os.unlink(tempFile)
        # Nodata is rendered transparent: all four RGBA values are 0.
        correctedDataset.SetMetadataItem('NODATA_VALUES', '0 0 0 0')
        self.warpedDataset = correctedDataset

    def updateAlphaForNonAlphaData(self):
        """Add an alpha band to 1- or 3-band warped data; other data is kept as is."""
        warpedDataset = self.warpedDataset
        if warpedDataset.RasterCount in [1, 3]:
            tempfilename = gettempfilename('-gdal2tiles.vrt')
            warpedDataset.GetDriver().CreateCopy(tempfilename, warpedDataset)
            # Read explicitly as UTF-8: with the locale default (e.g. GBK on
            # Chinese Windows) non-ASCII paths stored in the VRT get garbled.
            with open(tempfilename, encoding='utf-8') as f:
                orig_data = f.read()
            alpha_data = add_alpha_band_to_string_vrt(orig_data)
            with open(tempfilename, 'w', encoding='utf-8') as f:
                f.write(alpha_data)
            warpedDataset = gdal.Open(tempfilename)
            os.unlink(tempfilename)
        self.warpedDataset = warpedDataset

    def openData(self):
        """Open the input, warp it to EPSG:3857 and derive zoom and tile ranges.

        Sets ``warpedDataset`` and writes it to ``vrtFilename``. Fills
        ``tminz``/``tmaxz`` when no zoom was given, the projected extent
        ``ominx/ominy/omaxx/omaxy`` in meters, and ``tminmax``: the TMS tile
        range ``(tminx, tminy, tmaxx, tmaxy)`` for zooms 0-31. Exits via
        ``check`` on unreadable, band-less or non-georeferenced input.
        """
        gdal.AllRegister()
        self.mercator = GlobalMercator()
        inputDataset = gdal.Open(self.inputFile, gdal.GA_ReadOnly)
        check(not inputDataset, "数据无法打开")
        check(inputDataset.RasterCount == 0, "数据无波段")
        GetGeoTransform = inputDataset.GetGeoTransform()
        gcpCount = inputDataset.GetGCPCount()
        check(GetGeoTransform == (0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
              and gcpCount == 0, "数据缺少空间信息")
        # Always reproject to Web Mercator (EPSG:3857), the projection used by
        # web map tile services.
        inputSrs = osr.SpatialReference()
        if self.options.s_srs:
            inputSrs.SetFromUserInput(self.options.s_srs)
        else:
            inputSrs.ImportFromWkt(inputDataset.GetProjection())
        outputSrs = osr.SpatialReference()
        outputSrs.ImportFromEPSG(3857)
        # From here on the data is accessed through a warped VRT.
        self.warpedDataset = gdal.AutoCreateWarpedVRT(inputDataset,
                                                      inputSrs.ExportToWkt(),
                                                      outputSrs.ExportToWkt())

        # Render nodata as transparent; without nodata values, 1- and 3-band
        # data get an alpha band instead.
        inNodata = []
        for i in range(1, inputDataset.RasterCount+1):
            rasterNoData = inputDataset.GetRasterBand(i).GetNoDataValue()
            if rasterNoData is not None:
                inNodata.append(rasterNoData)
        if inNodata:
            self.updateNoDataValue()
        else:
            self.updateAlphaForNonAlphaData()
        # GCJ-02 ("Mars") shift: the offset computed for the top-left corner
        # is applied to the whole raster by moving its origin; the pixel size
        # is kept.
        if self.options.gcj02:
            geotrans = self.warpedDataset.GetGeoTransform()
            lat, lng = self.mercator.MetersToLatLon(geotrans[0], geotrans[3])
            lat, lng = self.mercator.WGS84ToGCJ02(lat, lng)
            x, y = self.mercator.LatLonToMeters(lat, lng)
            warpedGeotrans = [x, geotrans[1], 0, y, 0, geotrans[5]]
            self.warpedDataset.SetGeoTransform(warpedGeotrans)
        # Persist the VRT so that tiling (possibly in other processes) can
        # reopen it.
        self.warpedDataset.GetDriver().CreateCopy(self.vrtFilename,
                                                  self.warpedDataset)
        outGeotrans = self.warpedDataset.GetGeoTransform()
        check((outGeotrans[2], outGeotrans[4]) != (0, 0), "不支持变形后的数据")
        # Compute the zoom range automatically when "-z" was not given.
        if self.tminz == -1:
            self.tminz = self.mercator.ZoomForPixelSize(
                outGeotrans[1] *
                max(self.warpedDataset.RasterXSize,
                    self.warpedDataset.RasterYSize) /
                float(TILESIZE))
            tmaxz = self.mercator.ZoomForPixelSize(outGeotrans[1])
            if tmaxz < 6:
                tmaxz = 6
            self.tmaxz = tmaxz
        # Extent of the warped data in EPSG:3857 meters. The y extent uses the
        # pixel width, i.e. assumes square pixels.
        self.ominx = outGeotrans[0]
        self.omaxx = outGeotrans[0] + \
            self.warpedDataset.RasterXSize * outGeotrans[1]
        self.omaxy = outGeotrans[3]
        self.ominy = outGeotrans[3] - \
            self.warpedDataset.RasterYSize * outGeotrans[1]
        self.tminmax = list(range(0, 32))
        # TMS tile range for every zoom (cheap enough to do for all 32).
        for tz in range(0, 32):
            # 512-pixel tiles use the grid of one zoom level lower.
            _tz = int(tz - (TILESIZE / 256 - 1))
            tminx, tminy = self.mercator.MetersToTile(
                self.ominx, self.ominy, _tz)
            tmaxx, tmaxy = self.mercator.MetersToTile(
                self.omaxx, self.omaxy, _tz)
            tminx, tminy = max(0, tminx), max(0, tminy)
            tmaxx, tmaxy = min(2**_tz - 1, tmaxx), min(2**_tz - 1, tmaxy)
            self.tminmax[tz] = (tminx, tminy, tmaxx, tmaxy)

    def makeMetadata(self):
        """Write ``metadata.json`` (lat/lng bbox, tile size, zoom range, gcj02 flag)."""
        # Bounding box in lat/lng, used by front-end maps to zoom to the data.
        minlat, minlng = self.mercator.MetersToLatLon(self.ominx, self.ominy)
        maxlat, maxlng = self.mercator.MetersToLatLon(self.omaxx, self.omaxy)
        minlat, minlng = max(-85.05112878, minlat), max(-180.0, minlng)
        maxlat, maxlng = min(85.05112878, maxlat), min(180.0, maxlng)
        bbox = [[minlat, minlng], [maxlat, maxlng]]
        metadata = {"bbox": bbox, "tilesize": TILESIZE, 'minZoom': self.tminz,
                    'maxZoom': self.tmaxz, "gcj02": self.options.gcj02}
        with open(os.path.join(self.outputFolder, 'metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=4)

    def makeBaseTiles(self):
        """Return ``(TileJobInfo, [TileDetail, ...])`` for every tile at ``tmaxz``.

        Also creates the ``<output>/<z>/<x>`` folders the tiles will go to.
        """
        tminx, tminy, tmaxx, tmaxy = self.tminmax[self.tmaxz]
        tileDetails = []
        tz = self.tmaxz
        # Iterate rows from the northernmost TMS row down.
        for ty in range(tmaxy, tminy - 1, -1):
            for tx in range(tminx, tmaxx + 1):
                # Files are named by the XYZ row, like createBaseTile does.
                ty_final = GoogleTile(tz, ty)
                tilefilename = os.path.join(self.outputFolder, str(
                    tz), str(tx), "%s.%s" % (ty_final, TILEEXT))
                if not os.path.exists(os.path.dirname(tilefilename)):
                    os.makedirs(os.path.dirname(tilefilename))
                # 512-pixel tiles use the grid of one zoom level lower.
                _tz = int(tz - (TILESIZE / 256 - 1))
                # Projected bounds of the tile.
                b = self.mercator.TileBounds(tx, ty, _tz)
                # Source read window and query-buffer write window.
                rb, wb = self.geoQuery(b[0], b[3], b[2], b[1])
                rx, ry, rxsize, rysize = rb
                wx, wy, wxsize, wysize = wb
                tileDetails.append(
                    TileDetail(
                        tx=tx,
                        ty=ty,
                        tz=tz,
                        rx=rx,
                        ry=ry,
                        rxsize=rxsize,
                        rysize=rysize,
                        wx=wx,
                        wy=wy,
                        wxsize=wxsize,
                        wysize=wysize,
                    ))
        conf = TileJobInfo(
            srcFile=self.vrtFilename,
            nbDataBands=self.dataBandsCount,
            outputFilePath=self.outputFolder,
            tminmax=self.tminmax,
            tminz=self.tminz,
            tmaxz=self.tmaxz,
        )
        return conf, tileDetails

    def geoQuery(self, ulx, uly, lrx, lry):
        """Map a tile's projected bounds to read/write pixel windows.

        Args:
            ulx, uly, lrx, lry: upper-left / lower-right tile corner in meters.

        Returns:
            ``((rx, ry, rxsize, rysize), (wx, wy, wxsize, wysize))``. The first
            tuple is the window to read from ``warpedDataset``, clipped to the
            raster. The second is where, and at what size, it goes inside the
            ``4 * TILESIZE`` square query buffer.
        """
        ds = self.warpedDataset
        geotran = ds.GetGeoTransform()
        # geotran[0]/[3]: x/y of the raster's top-left corner;
        # geotran[1]/[5]: pixel width/height.
        # Pixel offset of the tile's top-left corner in the source raster.
        rx = int((ulx - geotran[0]) / geotran[1] + 0.001)
        ry = int((uly - geotran[3]) / geotran[5] + 0.001)
        # Tile size measured in source pixels.
        rxsize = int((lrx - ulx) / geotran[1] + 0.5)
        rysize = int((lry - uly) / geotran[5] + 0.5)
        # The query buffer is 4x the tile size for better-quality resampling.
        wxsize, wysize = 4 * TILESIZE, 4 * TILESIZE
        wx = 0
        if rx < 0:
            rxshift = abs(rx)
            wx = int(wxsize * (float(rxshift) / rxsize))
            # Shrink both windows proportionally.
            wxsize = wxsize - wx
            rxsize = rxsize - int(rxsize * (float(rxshift) / rxsize))
            rx = 0
        if rx + rxsize > ds.RasterXSize:
            wxsize = int(wxsize * (float(ds.RasterXSize - rx) / rxsize))
            rxsize = ds.RasterXSize - rx
        wy = 0
        if ry < 0:
            ryshift = abs(ry)
            wy = int(wysize * (float(ryshift) / rysize))
            wysize = wysize - wy
            rysize = rysize - int(rysize * (float(ryshift) / rysize))
            ry = 0
        if ry + rysize > ds.RasterYSize:
            wysize = int(wysize * (float(ds.RasterYSize - ry) / rysize))
            rysize = ds.RasterYSize - ry
        return (rx, ry, rxsize, rysize), (wx, wy, wxsize, wysize)


class ProgressBar(object):
    """Minimal console progress bar.

    A title line is printed on creation. After ``start()`` prints "0", every
    2.5 % step prints a dot, and every 10 % step prints the percentage
    (followed by a newline at 100).
    """

    def __init__(self, total_items, title):
        sys.stdout.write("%s 共%d张 \n" % (title, total_items))
        self.total_items = total_items
        self.nb_items_done = 0
        self.current_progress = 0
        self.STEP = 2.5

    def start(self):
        sys.stdout.write("0")

    def updateProgress(self, nb_items=1):
        """Account for ``nb_items`` more finished items and print new ticks."""
        self.nb_items_done += nb_items
        percent = float(self.nb_items_done) / self.total_items * 100
        while self.current_progress + self.STEP <= percent:
            self.current_progress += self.STEP
            if self.current_progress % 10 != 0:
                sys.stdout.write(".")
                continue
            sys.stdout.write(str(int(self.current_progress)))
            if self.current_progress == 100:
                sys.stdout.write("\n")
        sys.stdout.flush()


class SingleProcessTiling(object):
    """Cut all tiles sequentially in the current process.

    The constructor runs the whole pipeline:
    - base tiles of the deepest zoom (``tiling``);
    - overview tiles for the shallower zooms (``createOverviewTiles``);
    - removal of the temporary VRT folder;
    - when ``COMPRESS`` is set, PNG compression (``compressPng``).
    """

    def __init__(self, inputFile, outputFolder, options):
        self.inputFile = inputFile
        self.outputFolder = outputFolder
        self.options = options
        # Number of tiles produced (base + overview).
        self.total = 0
        self.tiling()
        self.createOverviewTiles()
        # srcFile is the VRT that TileJobsMaker put in its own mkdtemp() folder.
        shutil.rmtree(os.path.dirname(self.tileJobInfo.srcFile))
        if COMPRESS:
            self.compressPng()

    def tiling(self):
        """Create every base tile (deepest zoom) one after another."""
        tileDetails = self.workerTileDetails()
        tilecount = len(tileDetails)
        self.total += tilecount
        progressBar = ProgressBar(tilecount, '切割顶层瓦片')
        progressBar.start()
        for tileDetail in tileDetails:
            self.createBaseTile(tileDetail)
            progressBar.updateProgress()

    def createBaseTile(self, tileDetail, queue=None):
        """Render one base tile from the source VRT and write it as a PNG.

        The source window ``rx/ry/rxsize/rysize`` is read at
        ``wxsize x wysize`` into a ``4 * TILESIZE`` query dataset at
        ``wx/wy``, then averaged down to ``TILESIZE``. A tile without source
        pixels is still written (empty). When ``queue`` is given (pool
        workers), one progress message is put on it after the tile is written.
        """
        gdal.AllRegister()
        tileJobInfo = self.tileJobInfo
        output = tileJobInfo.outputFilePath
        tilebands = tileJobInfo.nbDataBands
        # Every call opens its own read-only handle on the VRT.
        ds = gdal.Open(tileJobInfo.srcFile, gdal.GA_ReadOnly)
        memDrv = gdal.GetDriverByName('MEM')
        outDrv = gdal.GetDriverByName(TILEDRIVER)
        # The mask of band 1 becomes the tile's alpha channel.
        alphaband = ds.GetRasterBand(1).GetMaskBand()
        tx = tileDetail.tx
        ty = tileDetail.ty
        tz = tileDetail.tz
        rx = tileDetail.rx
        ry = tileDetail.ry
        rxsize = tileDetail.rxsize
        rysize = tileDetail.rysize
        wx = tileDetail.wx
        wy = tileDetail.wy
        wxsize = tileDetail.wxsize
        wysize = tileDetail.wysize
        # The query ("window") dataset is 4x the tile size.
        querysize = 4 * TILESIZE
        ty_final = GoogleTile(tz, ty)
        tilefilename = os.path.join(output, str(
            tz), str(tx), "%s.%s" % (ty_final, TILEEXT))
        # Final tile dataset.
        dstile = memDrv.Create('', TILESIZE, TILESIZE, tilebands)
        data = alpha = None
        if rxsize != 0 and rysize != 0 and wxsize != 0 and wysize != 0:
            # Data bands are 1..tilebands-1; the last band receives alpha.
            dataBands = list(range(1, tilebands))
            data = ds.ReadRaster(rx, ry, rxsize, rysize, wxsize,
                                 wysize, band_list=dataBands)
            alpha = alphaband.ReadRaster(
                rx, ry, rxsize, rysize, wxsize, wysize)
            if data:
                dsquery = memDrv.Create('', querysize, querysize, tilebands)
                dsquery.WriteRaster(wx, wy, wxsize, wysize,
                                    data, band_list=dataBands)
                dsquery.WriteRaster(wx, wy, wxsize, wysize,
                                    alpha, band_list=[tilebands])
                # Downsample the query dataset into the tile.
                self.scaleQueryToTile(dsquery, dstile, tilefilename)
                del dsquery
        del ds
        del data
        outDrv.CreateCopy(tilefilename, dstile, strict=0)
        del dstile
        # In pool mode, report progress to the printer process.
        if queue is not None:
            queue.put("tile %s %s %s" % (tx, ty, tz))

    def workerTileDetails(self):
        """Open the input, write metadata.json and return the base-tile jobs.

        Also stores the job settings in ``self.tileJobInfo``.
        """
        tileJobsMaker = TileJobsMaker(
            self.inputFile, self.outputFolder, self.options)
        tileJobsMaker.openData()
        tileJobsMaker.makeMetadata()
        conf, tileDetails = tileJobsMaker.makeBaseTiles()
        self.tileJobInfo = conf
        return tileDetails

    def compressPng(self):
        """Compress (or delete, when empty) every PNG under the output folder."""
        pngPaths = [os.path.join(root, file)
                    for root, dirs, files in os.walk(self.outputFolder, True)
                    for file in files if file.endswith('.png')]
        # Count the files actually found, which may differ from self.total
        # (e.g. PNGs left over from an earlier run).
        progressBar = ProgressBar(len(pngPaths), '压缩全部瓦片')
        progressBar.start()
        pngquant.config(quant_file=QUANTFILE, min_quality=70,
                        max_quality=95, speed=10)
        for realPath in pngPaths:
            self.compress(realPath)
            progressBar.updateProgress()

    def compress(self, imgPath, queue=None):
        """Compress one PNG in place with pngquant, or delete it if it is empty.

        Empty tiles are recognised by their exact file size (``EMPTY``). With
        ``queue`` (pool workers), pngquant is configured with a private temp
        file, and ``imgPath`` is reported on the queue before processing.
        """
        if queue is not None:
            # Each worker needs its own temporary file.
            tmp_file = os.path.join(tempfile.gettempdir(
            ), '{0}.quant.tmp.png'.format(uuid4().hex))
            pngquant.config(quant_file=QUANTFILE, min_quality=70,
                            max_quality=95, speed=5, tmp_file=tmp_file)
            queue.put(imgPath)
        # Empty tiles are not kept.
        if os.path.getsize(imgPath) == EMPTY[str(TILESIZE)]:
            os.remove(imgPath)
        else:
            pngquant.quant_image(imgPath, imgPath)

    def scaleQueryToTile(self, dsquery, dstile, tilefilename=''):
        """Resample every band of ``dsquery`` into ``dstile`` with 'average'.

        Exits via ``check`` if GDAL reports a non-zero status.
        """
        tilebands = dstile.RasterCount
        for i in range(1, tilebands + 1):
            res = gdal.RegenerateOverview(dsquery.GetRasterBand(
                i), dstile.GetRasterBand(i), 'average')
            check(res != 0, "概览生成失败 %s,%d" % (tilefilename, res))

    def createOverviewTiles(self):
        """Build each zoom from ``tmaxz - 1`` down to ``tminz`` from its children.

        Every tile is stitched from the (up to four) existing tiles of the next
        deeper zoom into a ``2 * TILESIZE`` query dataset. That dataset is then
        downsampled with ``scaleQueryToTile``.
        """
        tileJobInfo = self.tileJobInfo
        memDriver = gdal.GetDriverByName('MEM')
        outDriver = gdal.GetDriverByName(TILEDRIVER)
        tilebands = tileJobInfo.nbDataBands
        zooms = range(tileJobInfo.tmaxz - 1, tileJobInfo.tminz - 1, -1)
        tcount = 0
        for tz in zooms:
            tminx, tminy, tmaxx, tmaxy = tileJobInfo.tminmax[tz]
            tcount += (1 + abs(tmaxx - tminx)) * (1 + abs(tmaxy - tminy))
        if tcount == 0:
            return
        self.total += tcount
        progressBar = ProgressBar(tcount, '切割下层瓦片')
        progressBar.start()
        for tz in zooms:
            tminx, tminy, tmaxx, tmaxy = tileJobInfo.tminmax[tz]
            # Tile range of the next deeper zoom, i.e. of the children.
            cminx, cminy, cmaxx, cmaxy = tileJobInfo.tminmax[tz + 1]
            for ty in range(tmaxy, tminy - 1, -1):
                for tx in range(tminx, tmaxx + 1):
                    ty_final = GoogleTile(tz, ty)
                    tilefilename = os.path.join(self.outputFolder, str(
                        tz), str(tx), "%s.%s" % (ty_final, TILEEXT))
                    if not os.path.exists(os.path.dirname(tilefilename)):
                        os.makedirs(os.path.dirname(tilefilename))
                    dsquery = memDriver.Create(
                        '', 2 * TILESIZE, 2 * TILESIZE, tilebands)
                    dstile = memDriver.Create(
                        '', TILESIZE, TILESIZE, tilebands)
                    # Tile (tx, ty) covers the same area as the four children
                    # (2tx..2tx+1, 2ty..2ty+1) of the next deeper zoom.
                    for y in range(2 * ty, 2 * ty + 2):
                        for x in range(2 * tx, 2 * tx + 2):
                            # Only children that were actually generated.
                            if not (cminx <= x <= cmaxx and cminy <= y <= cmaxy):
                                continue
                            y_final = GoogleTile(tz + 1, y)
                            path = os.path.join(self.outputFolder, str(
                                tz + 1), str(x), "%s.%s" % (y_final, TILEEXT))
                            dsquerytile = gdal.Open(path, gdal.GA_ReadOnly)
                            check(dsquerytile is None, "瓦片无法打开 %s" % path)
                            # TMS rows grow northward: the odd (northern)
                            # child row goes into the top half of the window.
                            tileposy = (2 * ty + 1 - y) * TILESIZE
                            tileposx = (x - 2 * tx) * TILESIZE
                            tempRaster = dsquerytile.ReadRaster(
                                0, 0, TILESIZE, TILESIZE)
                            dsquery.WriteRaster(
                                tileposx, tileposy, TILESIZE, TILESIZE,
                                tempRaster,
                                band_list=list(range(1, tilebands + 1)))
                            del dsquerytile
                    # Downsample and write the tile.
                    self.scaleQueryToTile(
                        dsquery, dstile, tilefilename=tilefilename)
                    outDriver.CreateCopy(tilefilename, dstile, strict=0)
                    progressBar.updateProgress()


class MultiProcessTiling(SingleProcessTiling):
    """
    Multi-process variant: base tiles and PNG compression run on a process
    pool while a separate process prints progress. Overview tiles are still
    built by the inherited ``createOverviewTiles``.
    """

    def __init__(self, inputFile, outputFolder, options):
        super().__init__(inputFile, outputFolder, options)

    def tiling(self):
        """Create all base tiles in parallel."""
        tileDetails = self.workerTileDetails()
        self.total += len(tileDetails)
        self._runJobs(self.createBaseTile,
                      [(tileDetail, ) for tileDetail in tileDetails],
                      '切割顶层瓦片')

    def progressPrinter(self, queue, nb_jobs, title):
        """Print progress (multi-process mode): consume ``nb_jobs`` queue items."""
        pb = ProgressBar(nb_jobs, title)
        pb.start()
        for _ in range(nb_jobs):
            queue.get()
            pb.updateProgress()
            queue.task_done()

    def compressPng(self):
        """Compress every PNG under the output folder in parallel."""
        pngPaths = [os.path.join(root, file)
                    for root, dirs, files in os.walk(self.outputFolder, True)
                    for file in files if file.endswith('.png')]
        self._runJobs(self.compress, [(path, ) for path in pngPaths],
                      '压缩全部瓦片')

    def _runJobs(self, func, argsList, title):
        """Run ``func(*args, queue=queue)`` on a pool for each tuple in ``argsList``.

        ``func`` reports progress by putting one item on ``queue``. A task
        that raises may never report, so one placeholder is queued per failed
        task to let the progress printer finish. The first failure is then
        re-raised here instead of being silently dropped, which previously
        left ``join`` hanging forever.
        """
        processes = self.options.processes or cpu_count()
        # The context manager shuts the manager process down when done.
        with Manager() as manager:
            queue = manager.Queue()
            pool = Pool(processes=processes)
            results = [pool.apply_async(func, args, {"queue": queue})
                       for args in argsList]
            printer = Process(target=self.progressPrinter,
                              args=[queue, len(argsList), title])
            printer.start()
            pool.close()
            pool.join()
            failed = [result for result in results
                      if not result.successful()]
            for _ in failed:
                queue.put(None)
            printer.join()
        if failed:
            # Re-raise the worker's exception in the main process.
            failed[0].get()


def process_args(argv):
    """Parse the command line and prepare the output folder.

    Args:
        argv: command-line arguments without the program name.

    Returns:
        ``(inputFile, outputFolder, args)``. ``args`` is the argparse
        namespace (``zoom``, ``processes``, ``gcj02``, ``input``, ``s_srs``,
        ``output``). The output folder is created if it does not exist.

    Exits with status 3 (via ``check``) when the input is not an existing file.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-z", dest='zoom', metavar='切割级别', help="\t如 '12-20'")
    parser.add_argument("-p", dest='processes', metavar='进程数',
                        type=int, default=1, help='\t默认单进程')
    parser.add_argument("-g", dest='gcj02', action='store_true',
                        default=False, help='\t火星偏移')
    parser.add_argument("-i", dest='input', metavar='tif文件', required=True)
    parser.add_argument("-s", dest='s_srs', metavar='输入文件的参考系')
    parser.add_argument("-o", dest='output', metavar="输出文件夹",
                        help="\t可选,默认在输入文件夹下的同名文件夹")
    args = parser.parse_args(argv)
    inputFile = args.input
    check(not os.path.isfile(inputFile), "%s不存在或非文件" % inputFile)
    outputFolder = args.output
    if not outputFolder:
        # Strip only the last extension: "a.b.tif" -> "a.b". split('.')[0]
        # gave "a", and "" for names starting with a dot.
        tifname = os.path.splitext(os.path.basename(inputFile))[0]
        outputFolder = os.path.join(os.path.dirname(
            os.path.abspath(inputFile)), tifname)
    if not os.path.exists(outputFolder):
        os.makedirs(outputFolder)
    return inputFile, outputFolder, args


def main():
    """Script entry point: parse arguments, run tiling and print the elapsed time.

    One process (``-p 1``, the default) uses ``SingleProcessTiling``; any
    other value uses ``MultiProcessTiling``.
    """
    import time
    start = int(time.time())
    argv = gdal.GeneralCmdLineProcessor(sys.argv)
    # GDAL returns None when its generic options (e.g. --version) mean the
    # program should just exit; argv[1:] would then raise TypeError.
    if argv is None:
        return
    inputFile, outputFolder, options = process_args(argv[1:])
    if options.processes == 1:
        tilingClass = SingleProcessTiling
    else:
        tilingClass = MultiProcessTiling
    tilingClass(inputFile, outputFolder, options)
    print('全部结束,用时:%d秒' % (int(time.time())-start))


# Run main() only when executed as a script. This also prevents
# multiprocessing workers that re-import this module (e.g. with the "spawn"
# start method on Windows) from starting a new tiling run.
if __name__ == '__main__':
    main()

你可能感兴趣的:(GDAL)