MMSegmentation 自定义预测代码

能够将完整的大图切成小块，分别预测后再拼接回完整大图。

import os
import numpy as np
import cv2
from tqdm import tqdm
import rasterio as rio
from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
from rasterio import windows
import matplotlib.pyplot as plt
from itertools import product

def split_image(image_path, tile_size):
    """Split a raster image into square tiles of at most ``tile_size`` pixels.

    Tiles are produced column by column: for each column offset, every row
    offset is visited top to bottom (``itertools.product`` varies its last
    argument fastest). Windows are clipped to the image bounds, so tiles on
    the right/bottom edges may be smaller than ``tile_size``.

    :param image_path: path of the image to split (opened with rasterio).
    :param tile_size: tile edge length in pixels; must be positive.
    :return: ``(tiles, splitlen, total_meta)`` where ``tiles`` is a list of
        channels-last ``(rows, cols, bands)`` arrays, ``splitlen`` is the
        number of tiles in one tile column (the number of row offsets), and
        ``total_meta`` is a copy of the source dataset's metadata.
    :raises ValueError: if ``tile_size`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f'tile_size must be positive, got {tile_size}')
    tiles = []
    with rio.open(image_path) as src:
        width, height = src.meta['width'], src.meta['height']
        total_meta = src.meta.copy()
        row_offsets = range(0, height, tile_size)
        splitlen = len(row_offsets)
        full_window = windows.Window(col_off=0, row_off=0, width=width, height=height)
        for col_off, row_off in product(range(0, width, tile_size), row_offsets):
            window = windows.Window(
                col_off=col_off, row_off=row_off, width=tile_size, height=tile_size
            ).intersection(full_window)
            # Read each window exactly once and reuse the array.
            tile = src.read(window=window)
            print('tile shape is ', tile.shape)
            # rasterio returns (bands, rows, cols); convert to channels-last
            # (rows, cols, bands) — presumably the HWC layout inference_model expects.
            tiles.append(np.transpose(tile, (1, 2, 0)))
    return tiles, splitlen, total_meta

def join_image(tiles, splitlen, meta, output_path):
    """Mosaic tiles back into one raster and write it to ``output_path``.

    ``tiles`` must be in the column-major tile order produced by
    ``split_image``: all tiles of the first tile column top to bottom, then
    the next column, and so on. Each tile is a ``(bands, rows, cols)`` array.

    :param tiles: list of ``(bands, rows, cols)`` arrays.
    :param splitlen: number of tiles per tile column.
    :param meta: rasterio metadata of the source image; copied, not modified.
    :param output_path: destination file path.
    :raises ValueError: if there is not at least one complete tile column.
    """
    if splitlen <= 0 or len(tiles) < splitlen:
        raise ValueError(
            f'need at least one complete column of {splitlen} tiles, got {len(tiles)}'
        )
    # Stack each tile column vertically (axis 1 = rows). A trailing incomplete
    # group, if any, is ignored.
    columns = [
        np.concatenate(tiles[start:start + splitlen], axis=1)
        for start in range(0, len(tiles) - splitlen + 1, splitlen)
    ]
    # Place the columns side by side (axis 2 = cols).
    big_merge = np.concatenate(columns, axis=2)

    out_meta = meta.copy()
    # Band count and dtype must describe the mosaic, not the source image.
    # The source nodata value is dropped: it may be out of range for the new
    # dtype, and e.g. nodata=0 would mask the zero channels of palette colors.
    out_meta.update(count=big_merge.shape[0], dtype=big_merge.dtype.name, nodata=None)
    with rio.open(output_path, 'w', **out_meta) as outds:
        outds.write(big_merge)

def process_single_img(img_path, output_path, save=True, tile_size=896,
                       seg_model=None, palette_map=None):
    """Run tiled semantic segmentation on one image and write a color mask.

    The image is split with ``split_image``. Each tile goes through
    ``inference_model``, predicted class ids are mapped to colors, and the
    colored tiles are mosaicked and written by ``join_image``.

    :param img_path: input image path.
    :param output_path: path the colored mask is written to.
    :param save: when False, the mosaic is not written to ``output_path``.
    :param tile_size: tile edge length in pixels (default 896).
    :param seg_model: model passed to ``inference_model``; defaults to the
        module-level ``model`` created in the ``__main__`` block.
    :param palette_map: mapping of class id to a 3-value color; defaults to
        the module-level ``palette_dict``.
    :return: list of colored ``(3, rows, cols)`` uint8 tiles in
        ``split_image`` order.
    """
    if seg_model is None:
        seg_model = model  # module-level global created in the __main__ block
    if palette_map is None:
        palette_map = palette_dict  # module-level global created in the __main__ block

    tiles_data, splitlen, total_meta = split_image(img_path, tile_size=tile_size)
    pred_mask_bgrs = []
    for img_4band in tiles_data:
        # Semantic segmentation prediction for one tile.
        print('img 4 band', img_4band.shape)
        result = inference_model(seg_model, img_4band)
        pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

        # Map each predicted class id to its color; ids missing from the
        # palette stay (0, 0, 0).
        pred_mask_bgr = np.zeros(pred_mask.shape[:2] + (3,), dtype=np.uint8)
        for idx, color in palette_map.items():
            pred_mask_bgr[pred_mask == idx] = color
        # Channels-first for rasterio: (rows, cols, 3) -> (3, rows, cols).
        pred_mask_bgrs.append(np.transpose(pred_mask_bgr, (2, 0, 1)))

    if save:
        join_image(pred_mask_bgrs, splitlen, total_meta, output_path)
    return pred_mask_bgrs

if __name__ == "__main__":

    # Model config file.
    config_file = r'/root/MMSegmentation_Tutorials-main/20230816/mmsegmentation/work_dirs/GF2MultiDataset-Swintrans/GF2MultiDataset_Swintrans_20230904.py'

    # Model checkpoint (weights) file.
    checkpoint_file = r'/root/MMSegmentation_Tutorials-main/20230816/mmsegmentation/work_dirs/GF2MultiDataset-Swintrans/best_mIoU_iter_27700.pth'

    # Compute device; switch to 'cpu' when no GPU is available.
    # device = 'cpu'
    device = 'cuda:0'

    model = init_model(config_file, checkpoint_file, device=device)
    # Color per class; the list index is the class id.
    # NOTE(review): described as BGR, but these triples are written to the
    # output raster as bands 1-3 in this order. Viewers that treat band 1 as
    # red will show swapped colors — confirm the intended channel order.
    palette = [
        ['background', [120, 0, 0]],
        ['crop', [127, 127, 127]],
        ['road', [0, 200, 0]],
        ['water', [144, 238, 144]],
    ]

    palette_dict = {idx: color for idx, (_name, color) in enumerate(palette)}
    output_dir = r'/root/MMSegmentation_Tutorials-main/20230816/mmsegmentation/outputs/testset-pred_gf6'
    # Also creates missing parent directories; no error if it already exists.
    os.makedirs(output_dir, exist_ok=True)

    PATH_IMAGE = r'/root/autodl-tmp/test'
    for img_name in sorted(os.listdir(PATH_IMAGE)):
        img_path = os.path.join(PATH_IMAGE, img_name)
        if not os.path.isfile(img_path):
            continue  # skip sub-directories and other non-file entries
        # Name the output after the input *file name*. Joining with the full
        # input path would return that absolute path and overwrite the input.
        output_path = os.path.join(output_dir, img_name)
        process_single_img(img_path, output_path)

你可能感兴趣的:(pytorch,深度学习,python)