python gdal 简单实现图像拼接

利用gdal将下列示例影像合并
python gdal 简单实现图像拼接_第1张图片

图像拼接

  • 一、读取影像
  • 二、横向合并矩阵
  • 三、利用相片坐标计算坐标位置
  • 四、计算仿射变换参数
  • 五、输出影像
  • 完整代码

一、读取影像

利用gdal读取遥感影像

from osgeo import gdal

def read_img(filename):
    dataset = gdal.Open(filename)  # 打开文件
    if dataset == None:
        raise Exception(f"cant find/open {filename}")
    im_width = dataset.RasterXSize  # 栅格矩阵的列数
    im_height = dataset.RasterYSize  # 栅格矩阵的行数
    im_Band = dataset.RasterCount  # 栅格矩阵的波段数

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)  # 将数据写成数组,对应栅格矩阵

    del dataset  # 关闭对象,文件dataset
    return im_proj, im_geotrans, im_data, im_width, im_height, im_Band

二、横向合并矩阵

利用numpy合并矩阵

    im_proj, im_geotrans, im_data1, im_width, im_height, im_Band = read_img(filepath1)
    _, _, im_data2, _, _, _ = read_img(filepath2)
    _, _, im_data3, _, _, _ = read_img(filepath3)
    _, _, im_data4, _, _, _ = read_img(filepath4)
    _, _, im_data5, _, _, _ = read_img(filepath5)
    im_data = np.concatenate((im_data1, im_data2, im_data3, im_data4, im_data5), axis=2)
    print(f"im_data1.shape:{im_data1.shape}")
    print(f"im_data2.shape:{im_data2.shape}")
    print(f"im_data3.shape:{im_data3.shape}")
    print(f"im_data4.shape:{im_data4.shape}")
    print(f"im_data5.shape:{im_data5.shape}")
    print(f"im_data.shape :{im_data.shape}")

输出结果

im_data1.shape:(4, 1000, 1000)
im_data2.shape:(4, 1000, 1000)
im_data3.shape:(4, 1000, 1000)
im_data4.shape:(4, 1000, 1000)
im_data5.shape:(4, 1000, 1000)
im_data.shape :(4, 1000, 5000)

三、利用相片坐标计算坐标位置

根据仿射变换参数计算坐标位置

def imagexy2geo(im_geotrans, row, col):
    """
    相片坐标计算坐标位置
    :param im_geotrans:图像放射变换参数
    :param row: 行数
    :param col: 列数
    :return: 坐标位置
    """
    px = im_geotrans[0] + col * im_geotrans[1] + row * im_geotrans[2]
    py = im_geotrans[3] + col * im_geotrans[4] + row * im_geotrans[5]
    return [px, py]

四、计算仿射变换参数

根据新坐标位置重算仿射变换参数

def setGeotrans(im_geotrans, row, col):
    """
    根据影像大小重算仿射变换参数
    :param im_geotrans: 
    :param row: 行数
    :param col: 列数
    :return: 仿射变换参数
    """
    coords00 = imagexy2geo(im_geotrans, 0, 0)
    coords01 = imagexy2geo(im_geotrans, row, 0)
    coords10 = imagexy2geo(im_geotrans, 0, col)

    trans = [0 for i in range(6)]
    trans[0] = coords00[0]
    trans[3] = coords00[1]
    trans[2] = (coords01[0] - trans[0]) / row
    trans[5] = (coords01[1] - trans[3]) / row
    trans[1] = (coords10[0] - trans[0]) / col
    trans[4] = (coords10[1] - trans[3]) / col
    return trans

五、输出影像

根据新坐标位置重算仿射变换参数

DType2GDAL = {"uint8": gdal.GDT_Byte,
              "uint16": gdal.GDT_UInt16,
              "int16": gdal.GDT_Int16,
              "uint32": gdal.GDT_UInt32,
              "int32": gdal.GDT_Int32,
              "float32": gdal.GDT_Float32,
              "float64": gdal.GDT_Float64,
              "cint16": gdal.GDT_CInt16,
              "cint32": gdal.GDT_CInt32,
              "cfloat32": gdal.GDT_CFloat32,
              "cfloat64": gdal.GDT_CFloat64}


def write_img(filename, im_proj, im_geotrans, im_data):
    # 判断栅格数据的数据类型
    if im_data.dtype.name in DType2GDAL:
        datatype = DType2GDAL[im_data.dtype.name]
    else:
        datatype = gdal.GDT_Float32
    # 判读数组维数
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
    # 创建文件
    if not pathlib.Path(filename).parent.exists():
        pathlib.Path(filename).parent.mkdir(parents=True, exist_ok=True)
    driver = gdal.GetDriverByName("GTiff")  # 数据类型必须有,因为要计算需要多大内存空间
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
    dataset.SetProjection(im_proj)  # 写入投影
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)  # 写入数组数据
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset


完整代码

# -*- coding: utf-8 -*-
"""
 @ time: 2022/11/11 13:49
 @ file:
 @ author: QYD2001
"""
import pathlib

import numpy as np
from osgeo import gdal

DType2GDAL = {"uint8": gdal.GDT_Byte,
              "uint16": gdal.GDT_UInt16,
              "int16": gdal.GDT_Int16,
              "uint32": gdal.GDT_UInt32,
              "int32": gdal.GDT_Int32,
              "float32": gdal.GDT_Float32,
              "float64": gdal.GDT_Float64,
              "cint16": gdal.GDT_CInt16,
              "cint32": gdal.GDT_CInt32,
              "cfloat32": gdal.GDT_CFloat32,
              "cfloat64": gdal.GDT_CFloat64}


def write_img(filename, im_proj, im_geotrans, im_data):
    # 判断栅格数据的数据类型
    if im_data.dtype.name in DType2GDAL:
        datatype = DType2GDAL[im_data.dtype.name]
    else:
        datatype = gdal.GDT_Float32
    # 判读数组维数
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
    # 创建文件
    if not pathlib.Path(filename).parent.exists():
        pathlib.Path(filename).parent.mkdir(parents=True, exist_ok=True)
    driver = gdal.GetDriverByName("GTiff")  # 数据类型必须有,因为要计算需要多大内存空间
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
    dataset.SetProjection(im_proj)  # 写入投影
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)  # 写入数组数据
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset


def read_img(filename):
    dataset = gdal.Open(filename)  # 打开文件
    if dataset == None:
        raise Exception(f"cant find/open {filename}")
    im_width = dataset.RasterXSize  # 栅格矩阵的列数
    im_height = dataset.RasterYSize  # 栅格矩阵的行数
    im_Band = dataset.RasterCount  # 栅格矩阵的行数

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)  # 将数据写成数组,对应栅格矩阵

    del dataset  # 关闭对象,文件dataset
    return im_proj, im_geotrans, im_data, im_width, im_height, im_Band


def imagexy2geo(im_geotrans, row, col):
    """
    相片坐标计算坐标位置
    :param im_geotrans:图像放射变换参数
    :param row: 行数
    :param col: 列数
    :return: 坐标位置
    """
    px = im_geotrans[0] + col * im_geotrans[1] + row * im_geotrans[2]
    py = im_geotrans[3] + col * im_geotrans[4] + row * im_geotrans[5]
    return [px, py]


def setGeotrans(im_geotrans, row, col):
    """
    根据影像大小重算仿射变换参数
    :param im_geotrans:
    :param row:
    :param col:
    :return: 仿射变换参数
    """
    coords00 = imagexy2geo(im_geotrans, 0, 0)
    coords01 = imagexy2geo(im_geotrans, row, 0)
    coords10 = imagexy2geo(im_geotrans, 0, col)

    trans = [0 for i in range(6)]
    trans[0] = coords00[0]
    trans[3] = coords00[1]
    trans[2] = (coords01[0] - trans[0]) / row
    trans[5] = (coords01[1] - trans[3]) / row
    trans[1] = (coords10[0] - trans[0]) / col
    trans[4] = (coords10[1] - trans[3]) / col
    return trans


if __name__ == '__main__':
    filepath1 = r"E:\RsData\1108\1m\1m\LNPJ210912_GF2PMS0005882472_001_001.tif"
    filepath2 = r"E:\RsData\1108\1m\1m\LNPJ210912_GF2PMS0005882472_001_002.tif"
    filepath3 = r"E:\RsData\1108\1m\1m\LNPJ210912_GF2PMS0005882472_001_003.tif"
    filepath4 = r"E:\RsData\1108\1m\1m\LNPJ210912_GF2PMS0005882472_001_004.tif"
    filepath5 = r"E:\RsData\1108\1m\1m\LNPJ210912_GF2PMS0005882472_001_005.tif"
    outpath = r"E:\RsData\1108\1m\1m\LNPJ210912_GF2PMS0005882472_001.tif"

    im_proj, im_geotrans, im_data1, im_width, im_height, im_Band = read_img(filepath1)
    _, _, im_data2, _, _, _ = read_img(filepath2)
    _, _, im_data3, _, _, _ = read_img(filepath3)
    _, _, im_data4, _, _, _ = read_img(filepath4)
    _, _, im_data5, _, _, _ = read_img(filepath5)
    im_data = np.concatenate((im_data1, im_data2, im_data3, im_data4, im_data5), axis=2)
    print(f"im_data1.shape:{im_data1.shape}")
    print(f"im_data2.shape:{im_data2.shape}")
    print(f"im_data3.shape:{im_data3.shape}")
    print(f"im_data4.shape:{im_data4.shape}")
    print(f"im_data5.shape:{im_data5.shape}")
    print(f"im_data.shape :{im_data.shape}")

    geotrans = setGeotrans(im_geotrans, im_data.shape[0], im_data.shape[1])
    write_img(outpath, im_proj, geotrans, im_data)


输出结果
python gdal 简单实现图像拼接_第2张图片
圆满完成任务@-@

你可能感兴趣的:(python,numpy,gdal)