python——遥感影像分块

python 遥感影像分块

在深度学习或其他一些应用中,往往需要对较大的遥感影像进行分块输出,并且保留原始的地理信息,以备后续使用。

本篇文章的目的即对遥感影像进行分块输出,按照网格顺序进行重新命名,其中边缘部分可能存在不足裁剪大小的情况,通过向前填补方式进行填充,故分块结果可能导致边缘部分存在一定重复率。

具体代码如下:

#!usr/bin/env python
# -*- coding: utf-8 -*-
"""
date:2022/9/2
author:甲戌_Tr
email: [email protected]
"""

import sys,os
from osgeo import gdal
import numpy as np
from pathlib import Path as Path

class TifCrop:
    def __init__(self,infile,crop_size,save_path,repete_rate=0):
        '''
        遥感影像分块函数
        :param infile: 输入tif文件
        :param crop_size: 分块大小,单值或元祖,int型。eg:200表示以 200*200个像元大小的方形进行分块,(100,200)表示以 100*200个像元大小的矩形进行分块
        :param repete_size:重复率, float, 其中值的范围为[0,1)之间, 默认值为0
        '''
        self.infile = infile
        self.crop_size = crop_size
        self.save_path = save_path
        self.repete_rate = repete_rate

        # crop_size 参数判断
        if not isinstance(crop_size, int):
            if not isinstance(crop_size, tuple):
                raise Exception('crop_size 输入参数错误')
            else:
                if not (isinstance(crop_size[0], int) and isinstance(crop_size[1], int)):
                    raise Exception('crop_size 输入参数错误')

        # repete_rate 参数判断
        if repete_rate >= 1 or repete_rate < 0:
            raise Exception('repete_rate 输出参数错误')

    def crop_tif(self):
        if isinstance(self.crop_size,tuple):
            crop_size_r = self.crop_size[0]
            crop_size_c = self.crop_size[1]
        else:
            crop_size_r = self.crop_size
            crop_size_c = self.crop_size

        repete_size_r = int(crop_size_r * (1 - self.repete_rate))
        repete_size_c = int(crop_size_c * (1 - self.repete_rate))

        ds = gdal.Open(self.infile)
        data = ds.ReadAsArray()

        geotrans = ds.GetGeoTransform()
        self.projection = ds.GetProjection()

        # 将单波段影像添加一个维度
        if len(data.shape) == 2:
            data = np.array([data])

        channel, rows, cols = data.shape

        # 向上取整
        col_num = int(np.ceil(cols / repete_size_c))
        row_num = int(np.ceil(rows / repete_size_r))

        # 循环读取
        # 边缘部分按照向前扩充原则进行提取

        # 当重复率较高或分块尺寸较小时,遇到边缘部分可能存在分割相同的情况,故用以下参数进行判断避免该情况发生
        start_point = (-1,-1)

        for i in range(col_num):
            for j in range(row_num):
                row_s = repete_size_r * j
                row_e = repete_size_r * j + crop_size_r

                # 是否超出边界判断
                if row_e > rows:
                    row_s = rows - crop_size_r
                    row_e = rows

                col_s = repete_size_c * i
                col_e = repete_size_c * i + crop_size_c

                # 是否超出边界判断
                if col_e > cols:
                    col_s = cols - crop_size_c
                    col_e = cols

                data_crop = data[:,row_s:row_e,col_s:col_e]

                # 判断输出内容是否与之前存在重复情况,非完全重叠部分再进行分块输出
                if (row_s,col_s) != start_point:
                    start_point = (row_s,col_s)
                    # 地理信息存放
                    new_geotrans = (geotrans[0] + geotrans[1] * col_s, geotrans[1], geotrans[2], geotrans[3] + geotrans[5] * row_s,geotrans[4],geotrans[5])

                    # 输出文件名称
                    out_file = self.save_path + os.sep + Path(self.infile).stem + '_' + str(j) + '_' + str(i) + Path(self.infile).suffix

                    self.tif_write(data_crop,new_geotrans,out_file)

    def tif_write(self,data,trans,ofile):
        '''
        tif写入
        :param data: 分块后数组
        :param trans: 更新后的geotransform,包括六参数
        :param ofile: 输出全路径
        :return: None
        '''
        # 数据类型获取
        if 'int8' in data.dtype.name:
            datatype = gdal.GDT_Byte
        elif 'int16' in data.dtype.name:
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32

        # 输出tif文件按照单波段或多波段划分
        bands,height,width = data.shape

        # 创建文件
        driver = gdal.GetDriverByName("GTiff")
        dataset = driver.Create(ofile, int(width), int(height), int(bands), datatype)

        if (dataset != None):
            dataset.SetGeoTransform(trans)  # 写入仿射变换参数
            dataset.SetProjection(self.projection)  # 写入投影
        for i in range(bands):
            dataset.GetRasterBand(i + 1).WriteArray(data[i])
        del dataset

if __name__ == '__main__':
    file = r'F:\input.tif'
    cropsize = 300
    savepath = r'F:\outpath'
    TifCrop(file,cropsize,savepath).crop_tif()

欢迎指正~

你可能感兴趣的:(python)