Splits a full-size image into tiles, runs segmentation inference on each tile, and stitches the per-tile predictions back into one large georeferenced image.
import os
import numpy as np
from tqdm import tqdm
import rasterio as rio
from rasterio import windows
from itertools import product
from mmseg.apis import init_model, inference_model
def split_image(image_path, tile_size):
    """
    Split an image into tiles of a given size.
    :param image_path: path of the image to split.
    :param tile_size: tile size in pixels.
    :return: (tiles, splitlen, total_meta) - the list of tile arrays in (H, W, C) order,
             the number of tiles per column, and the metadata of the full image.
    """
    tiles = []
    with rio.open(image_path) as inds:
        ncols, nrows = inds.meta['width'], inds.meta['height']
        total_meta = inds.meta.copy()
        # Grid of tile origins; product() varies the row offset fastest,
        # so tiles are produced column by column.
        offsets = product(range(0, ncols, tile_size), range(0, nrows, tile_size))
        splitlen = len(range(0, nrows, tile_size))  # number of tiles per column
        big_window = windows.Window(col_off=0, row_off=0, width=ncols, height=nrows)
        for col_off, row_off in offsets:
            # Clip at the image border, so edge tiles may be smaller than tile_size
            window = windows.Window(col_off=col_off, row_off=row_off,
                                    width=tile_size, height=tile_size).intersection(big_window)
            # Per-tile georeferencing (computed for reference; not returned)
            transform = windows.transform(window, inds.transform)
            meta = inds.meta.copy()
            meta['transform'] = transform
            meta['width'], meta['height'] = window.width, window.height
            tile = inds.read(window=window)            # (bands, rows, cols)
            print('tile shape is ', tile.shape)
            tile_hwc = np.einsum('ijk->jki', tile)     # -> (rows, cols, bands) for inference
            tiles.append(tile_hwc)
    return tiles, splitlen, total_meta
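# A quick worked example of the tiling grid (the 2000 x 1500 image size below is only
# illustrative): with tile_size=896,
#   range(0, 2000, 896) -> column offsets [0, 896, 1792]  (3 columns of tiles)
#   range(0, 1500, 896) -> row offsets    [0, 896]        (2 rows of tiles, so splitlen = 2)
# product() yields the 6 tiles column by column, which is exactly the order join_image()
# relies on when it regroups every `splitlen` consecutive tiles.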
def join_image(tiles, splitlen, meta, output_path):
    """
    Stitch predicted tiles in (C, H, W) order back into one raster and write it
    with the georeferencing of the original image.
    """
    axis_y_merge = []
    for num, tile in enumerate(tiles):
        # Every `splitlen` consecutive tiles share the same column offset:
        # stack them vertically (axis=1 is the row axis of a (C, H, W) array).
        if (num + 1) % splitlen == 0:
            temp_tiles_data = tiles[(num - (splitlen - 1)):(num + 1)]
            print(len(temp_tiles_data))
            axis_y_merge.append(np.concatenate(temp_tiles_data, axis=1))
    print(len(axis_y_merge))
    # Stack the column strips horizontally (axis=2 is the column axis).
    big_merge = np.concatenate(axis_y_merge, axis=2)
    print(big_merge.shape)
    meta.update(count=big_merge.shape[0], dtype='uint8')  # 3-band uint8 colour mask
    with rio.open(output_path, 'w', **meta) as outds:
        outds.write(big_merge)
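# Note on join_image(): the tiles it receives are channel-first (3, H, W) colour masks.
# Within one column the clipped edge tiles all share the same (possibly reduced) width,
# so the axis=1 stacking works, and every column strip ends up with the full image
# height, so the final axis=2 concatenation also lines up without any padding.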
def process_single_img(img_path, output_path, save=True):
    tiles_data, splitlen, total_meta = split_image(img_path, tile_size=896)
    pred_mask_bgrs = []
    for img_4band in tiles_data:
        # Semantic segmentation inference on one tile
        print('img 4 band', img_4band.shape)
        result = inference_model(model, img_4band)
        pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
        # Map the predicted class IDs to the colour of each class
        pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
        for idx in palette_dict.keys():
            pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
        pred_mask_bgr = pred_mask_bgr.astype('uint8')
        pred_mask_bgr_out = np.einsum('ijk->kij', pred_mask_bgr)  # (H, W, 3) -> (3, H, W)
        pred_mask_bgrs.append(pred_mask_bgr_out)
    join_image(pred_mask_bgrs, splitlen, total_meta, output_path)
if __name__ == "__main__":
    # Model config file
    config_file = r'/root/MMSegmentation_Tutorials-main/20230816/mmsegmentation/work_dirs/GF2MultiDataset-Swintrans/GF2MultiDataset_Swintrans_20230904.py'
    # Model checkpoint file
    checkpoint_file = r'/root/MMSegmentation_Tutorials-main/20230816/mmsegmentation/work_dirs/GF2MultiDataset-Swintrans/best_mIoU_iter_27700.pth'
    # Compute device
    # device = 'cpu'
    device = 'cuda:0'
    model = init_model(config_file, checkpoint_file, device=device)
    # BGR colour for each class; rasterio writes the bands in exactly the order given here
    palette = [
        ['background', [120, 0, 0]],
        ['crop', [127, 127, 127]],
        ['road', [0, 200, 0]],
        ['water', [144, 238, 144]],
    ]
    palette_dict = {}
    for idx, each in enumerate(palette):
        palette_dict[idx] = each[1]
    output_dir = r'/root/MMSegmentation_Tutorials-main/20230816/mmsegmentation/outputs/testset-pred_gf6'
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    PATH_IMAGE = r'/root/autodl-tmp/test'
    for img_name in tqdm(os.listdir(PATH_IMAGE)):
        img_path = os.path.join(PATH_IMAGE, img_name)
        output_path = os.path.join(output_dir, img_name)
        process_single_img(img_path, output_path)
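    # A minimal sanity check (assumes at least one image was processed, so that
    # output_path still points at the last prediction written above): the stitched
    # result should be a 3-band uint8 raster with the georeferencing of its input.
    with rio.open(output_path) as ds:
        print('last prediction:', ds.count, 'bands,', ds.width, 'x', ds.height, ds.crs)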