Part 1: Using the PaddleSeg Toolkit to Predict on Remote-Sensing Images (Basics)

At present, the predict.py script in the PaddleSeg toolkit does not really support predicting directly on remote-sensing imagery (large images): if you feed a large remote-sensing image straight into predict.py, the prediction quality is very poor.

To address this, this post rewrites predict.py by combining the PaddleSeg predict.py source with the code from the blog post "遥感语义分割切图预测之后再拼接" (tile-wise prediction and re-stitching for remote-sensing semantic segmentation). Hopefully it helps readers doing remote-sensing semantic segmentation with the PaddlePaddle framework. You should already know how to use the PaddleSeg toolkit, or at least be somewhat familiar with its source code; one blogger ("人工智能研习社") has written a series of articles on the PaddleSeg source that is worth studying.

The rewritten predict.py has four main parts:

1. read the large remote-sensing image and crop it into tiles;
2. run model inference on each tile;
3. stitch the per-tile predictions back together;
4. write the stitched result to image files.

1. Read the large remote-sensing image to be predicted and crop it into tiles. Adjacent tiles are cropped without overlap, and each tile is 256x256.
The code for this part is as follows:

# Read the large remote-sensing image to predict, e.g.
# img_lists[local_rank][local_rank] = /home/aistudio/data/data70483/img.png
ori_image = cv2.imread(img_lists[local_rank][local_rank])
h_step = ori_image.shape[0] // 256  # number of full tile rows
w_step = ori_image.shape[1] // 256  # number of full tile columns

# remaining rows/columns, stored as negative values so they can be used
# directly as negative indices when stitching (assumes the height and width
# are not exact multiples of 256)
h_rest = -(ori_image.shape[0] - 256 * h_step)
w_rest = -(ori_image.shape[1] - 256 * w_step)

seg_list = []      # list of cropped tiles
predict_list = []  # list of per-tile prediction results
# loop over the grid and crop
for h in range(h_step):
    for w in range(w_step):
        # sliding-window sampling
        image_sample = ori_image[(h * 256):(h * 256 + 256),
                       (w * 256):(w * 256 + 256), :]
        seg_list.append(image_sample)
    # one extra tile per row for the right edge
    seg_list.append(ori_image[(h * 256):(h * 256 + 256), -256:, :])
# bottom-edge tiles
for w in range(w_step):
    seg_list.append(ori_image[-256:, (w * 256):(w * 256 + 256), :])
# bottom-right corner tile
seg_list.append(ori_image[-256:, -256:, :])
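
To get a feel for the index arithmetic, here is a quick self-contained check (the 1000 x 1300 image size is made up for illustration and is not from this post):

import numpy as np

# Assumed example: a 1000 x 1300 BGR image (any image of at least 256 x 256 whose
# sides are not exact multiples of 256 behaves the same way).
ori_image = np.zeros((1000, 1300, 3), dtype=np.uint8)

h_step = ori_image.shape[0] // 256             # 3 full tile rows
w_step = ori_image.shape[1] // 256             # 5 full tile columns
h_rest = -(ori_image.shape[0] - 256 * h_step)  # -232, later used as a negative index
w_rest = -(ori_image.shape[1] - 256 * w_step)  # -20

# Each row of the grid yields w_step + 1 tiles (the extra one covers the right
# edge), and the bottom edge yields another w_step + 1, so:
total_tiles = (h_step + 1) * (w_step + 1)      # 24 tiles of 256 x 256
print(h_step, w_step, h_rest, w_rest, total_tiles)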

2. Run the network model on each tile. Little changes here compared with the original code, except that the img_lists[local_rank] argument is replaced by seg_list, the list that holds the tiles; the other arguments can be set as needed. In this post each tile gets only plain inference, with neither multi-scale nor sliding-window prediction (both are features of the original predict.py and could still be used here).
The code for this part is as follows:

progbar_pred = progbar.Progbar(target=len(seg_list), verbose=1)
with paddle.no_grad():
    for i, im in enumerate(seg_list):
        ori_shape = im.shape[:2]  # original tile shape (h, w)
        im, _ = transforms(im)    # im.shape: (3, 256, 256); the second return value is None
        im = im[np.newaxis, ...]   # im.shape: (1, 3, 256, 256)
        im = paddle.to_tensor(im)

        if False:  # aug_pred is hard-coded off in this first version
            pred = infer.aug_inference(
                model,
                im,
                ori_shape=ori_shape,
                transforms=transforms.transforms,
                scales=scales,
                flip_horizontal=flip_horizontal,
                flip_vertical=flip_vertical,
                is_slide=is_slide,
                stride=None,
                crop_size=None)
        else:
            pred = infer.inference(
                model,
                im,
                ori_shape=ori_shape,
                transforms=transforms.transforms,
                is_slide=False,
                stride=None,
                crop_size=None)
        pred = paddle.squeeze(pred)  # squeeze removes the size-1 dimensions; pred should now have shape [h, w]
        pred = pred.numpy().astype('uint8')
        predict_list.append(pred)
        progbar_pred.update(i + 1)
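
For intuition about the shapes noted in the comments above, the transform step for a typical validation config boils down to scaling to [0, 1], normalizing, transposing HWC to CHW, and adding a batch dimension. A rough stand-alone sketch (assuming the config only uses Normalize with mean = std = 0.5; PaddleSeg's actual Compose may differ, e.g. it may also convert BGR to RGB depending on the version and config):

import numpy as np

def preprocess(tile, mean=0.5, std=0.5):
    # Rough equivalent of the validation transforms, for illustration only.
    im = tile.astype('float32') / 255.0   # scale to [0, 1]
    im = (im - mean) / std                # normalize
    im = np.transpose(im, (2, 0, 1))      # HWC -> CHW, e.g. (3, 256, 256)
    return im[np.newaxis, ...]            # add a batch dim -> (1, 3, 256, 256)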

3. Stitch the per-tile predictions back together. The idea is simple: paste the predicted tiles back in the same order in which they were cropped.
The code for this part is as follows:

count_temp = 0
tmp = np.ones([ori_image.shape[0], ori_image.shape[1]])  # canvas for the stitched mask
for h in range(h_step):
    for w in range(w_step):
        tmp[
        h * 256:(h + 1) * 256,
        w * 256:(w + 1) * 256
        ] = predict_list[count_temp]
        count_temp += 1
    # right edge: keep only the last |w_rest| columns of the edge tile
    tmp[h * 256:(h + 1) * 256, w_rest:] = predict_list[count_temp][:, w_rest:]
    count_temp += 1
# bottom edge: keep only the last |h_rest| rows of each bottom tile
for w in range(w_step):
    tmp[h_rest:, (w * 256):(w * 256 + 256)] = predict_list[count_temp][h_rest:, :]
    count_temp += 1
# bottom-right corner
tmp[-256:, -256:] = predict_list[count_temp][:, :]

4. Write the stitched result tmp to image files. This reuses the saving code of the original predict.py, only replacing its pred argument with tmp. Note that tmp must be converted to uint8 first, otherwise the program raises an error.
The code for this part is as follows:

tmp = tmp.astype('uint8')
# save added image
added_image = utils.visualize.visualize(args.image_path,tmp, weight=0.6)
added_image_path = os.path.join(added_saved_dir, im_file)
mkdir(added_image_path)
cv2.imwrite(added_image_path, added_image)
# save pseudo color prediction
pred_mask = utils.visualize.get_pseudo_color_map(tmp)
pred_saved_path = os.path.join(pred_saved_dir,
                                im_file.rsplit(".")[0] + ".png")
mkdir(pred_saved_path)
pred_mask.save(pred_saved_path)
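
If you later need the raw class labels rather than a colorized picture, note that get_pseudo_color_map saves a palette ('P' mode) PNG, so the pixel values are still the class indices and only the palette controls the colors. A small sketch for reading them back (reusing pred_saved_path from the snippet above):

import numpy as np
from PIL import Image

# Each pixel of the pseudo-color PNG stores the predicted class index;
# the palette only affects how it is displayed.
label = np.asarray(Image.open(pred_saved_path))
print(label.shape, label.dtype, np.unique(label))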

That wraps up the main body of the code. Note that in the PaddleSeg toolkit the top-level predict.py calls another predict.py inside paddleseg.core; to make the code easier to port, this post merges the two into a single predict.py.
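
For comparison, the stock top-level script is essentially a thin wrapper around the core API; a rough sketch of what it does (the exact keyword arguments vary between PaddleSeg releases, so treat this as an outline rather than the official signature):

from paddleseg.core import predict

# model, val_dataset, image_list, image_dir and args mirror the names used in
# the rewritten script below.
predict(
    model,                        # model built from the config
    model_path=args.model_path,   # trained weights (*.pdparams)
    transforms=val_dataset.transforms,
    image_list=image_list,        # from get_image_list(args.image_path)
    image_dir=image_dir,
    save_dir=args.save_dir)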

The first version of predict.py written at the time hard-codes the tile size to 256 and also hard-codes several of the inference arguments, so it is not recommended to copy it as-is; treat it as reference material. The complete code of the first version is as follows:

import sys
import argparse
import os
import paddle
from paddleseg.cvlibs import manager, Config
from paddleseg.utils import get_sys_env, logger
import math
import cv2
import numpy as np
from paddleseg import utils
from paddleseg.core import infer
from paddleseg.utils import progbar

def mkdir(path):
    sub_dir = os.path.dirname(path)  # strip the file name, keep the parent directory
    if not os.path.exists(sub_dir):
        os.makedirs(sub_dir)

def partition_list(arr, m):
    """split the list 'arr' into m pieces"""
    n = int(math.ceil(len(arr) / float(m)))
    return [arr[i:i + n] for i in range(0, len(arr), n)]

def parse_args():
    parser = argparse.ArgumentParser(description='Model prediction')

    # params of prediction
    parser.add_argument(
        "--config", dest="cfg", help="The config file.", default=None, type=str)
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for prediction',
        type=str,
        default=None)
    parser.add_argument(
        '--image_path',
        dest='image_path',
        help=
        'The path of image, it can be a file or a directory including images',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the predicted results',
        type=str,
        default='./output/result')

    # augment for prediction
    parser.add_argument(
        '--aug_pred',
        dest='aug_pred',
        help='Whether to use multi-scale and flip augmentation for prediction',
        action='store_true')
    parser.add_argument(
        '--scales',
        dest='scales',
        nargs='+',
        help='Scales for augment',
        type=float,
        default=1.0)
    parser.add_argument(
        '--flip_horizontal',
        dest='flip_horizontal',
        help='Whether to use flip horizontally augment',
        action='store_true')
    parser.add_argument(
        '--flip_vertical',
        dest='flip_vertical',
        help='Whether to use flip vertically augment',
        action='store_true')

    # sliding window prediction
    parser.add_argument(
        '--is_slide',
        dest='is_slide',
        help='Whether to predict by sliding window',
        action='store_true')
    parser.add_argument(
        '--crop_size',
        dest='crop_size',
        nargs=2,
        help=
        'The crop size of sliding window, the first is width and the second is height.',
        type=int,
        default=None)
    parser.add_argument(
        '--stride',
        dest='stride',
        nargs=2,
        help=
        'The stride of sliding window, the first is width and the second is height.',
        type=int,
        default=None)

    return parser.parse_args()

def get_image_list(image_path):
    """Get image list"""
    valid_suffix = [
        '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png' ,'.tif'
    ]
    image_list = []
    image_dir = None
    if os.path.isfile(image_path):
        if os.path.splitext(image_path)[-1] in valid_suffix:
            image_list.append(image_path)
    elif os.path.isdir(image_path):
        image_dir = image_path
        for root, dirs, files in os.walk(image_path):  #root=image_path
            for f in files:
                if os.path.splitext(f)[-1] in valid_suffix:
                    image_list.append(os.path.join(root, f))
    else:
        raise FileNotFoundError(
            '`--image_path` is not found. it should be an image file or a directory including images'
        )

    if len(image_list) == 0:
        raise RuntimeError('There is no image file in `--image_path`')

    return image_list, image_dir   # return the list of images to predict and their directory


def main(args):
    env_info = get_sys_env()
    place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
        'GPUs used'] else 'cpu'

    paddle.set_device(place)
    if not args.cfg:
        raise RuntimeError('No configuration file specified.')

    cfg = Config(args.cfg)
    val_dataset = cfg.val_dataset   
    if not val_dataset:
        raise RuntimeError(
            'The verification dataset is not specified in the configuration file.'
        )

    msg = '\n---------------Config Information---------------\n'
    msg += str(cfg)
    msg += '------------------------------------------------'
    logger.info(msg)

    model = cfg.model
    transforms = val_dataset.transforms
    #image_list, image_dir = get_image_list('data/UAV_seg/images')
    image_list, image_dir = get_image_list(args.image_path)  # args.image_path can be a directory of test images or a single image file
    
    model_path=args.model_path  # path of the trained model weights
    save_dir=args.save_dir
    # the remaining options are hard-coded in this first version
    aug_pred=False
    scales=1.0
    flip_horizontal=True
    flip_vertical=False
    is_slide=False
    stride=None
    crop_size=None

    para_state_dict = paddle.load(model_path)
    model.set_dict(para_state_dict)
    model.eval()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    if nranks > 1:
        img_lists = partition_list(image_list, nranks)
    else:
        img_lists = [image_list]  # list of lists; img_lists[0] is the list of image paths

    added_saved_dir = os.path.join(save_dir, 'added_prediction')  # prediction blended with the original image
    pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')  # pseudo-color prediction

    logger.info("Start to predict...")
    

##############################  1. Crop the large remote-sensing image  ########################
    # Read the large remote-sensing image to predict, e.g.
    # img_lists[local_rank][local_rank] = /home/aistudio/data/data70483/img.png
    ori_image = cv2.imread(img_lists[local_rank][local_rank])
    h_step = ori_image.shape[0] // 256  # number of full tile rows
    w_step = ori_image.shape[1] // 256  # number of full tile columns

    # remaining rows/columns, stored as negative values so they can be used
    # directly as negative indices when stitching (assumes the height and width
    # are not exact multiples of 256)
    h_rest = -(ori_image.shape[0] - 256 * h_step)
    w_rest = -(ori_image.shape[1] - 256 * w_step)

    seg_list = []      # list of tiles cropped from the large image
    predict_list = []  # list of per-tile prediction results
    # loop over the grid and crop
    for h in range(h_step):
        for w in range(w_step):
            # sliding-window sampling
            image_sample = ori_image[(h * 256):(h * 256 + 256),
                           (w * 256):(w * 256 + 256), :]
            seg_list.append(image_sample)
        # one extra tile per row for the right edge
        seg_list.append(ori_image[(h * 256):(h * 256 + 256), -256:, :])
    # bottom-edge tiles
    for w in range(w_step):
        seg_list.append(ori_image[-256:, (w * 256):(w * 256 + 256), :])
    # bottom-right corner tile
    seg_list.append(ori_image[-256:, -256:, :])
##############################  end of cropping  ########################

##############################  2. Run model inference on the tiles  ########################
    progbar_pred = progbar.Progbar(target=len(seg_list), verbose=1)
    with paddle.no_grad():
        for i, im in enumerate(seg_list):
            ori_shape = im.shape[:2]  # original tile shape (h, w)
            im, _ = transforms(im)    # im.shape: (3, 256, 256); the second return value is None
            im = im[np.newaxis, ...]   # im.shape: (1, 3, 256, 256)
            im = paddle.to_tensor(im)

            if False:  # aug_pred is hard-coded off in this first version
                pred = infer.aug_inference(
                    model,
                    im,
                    ori_shape=ori_shape,
                    transforms=transforms.transforms,
                    scales=scales,
                    flip_horizontal=flip_horizontal,
                    flip_vertical=flip_vertical,
                    is_slide=is_slide,
                    stride=None,
                    crop_size=None)
            else:
                pred = infer.inference(
                    model,
                    im,
                    ori_shape=ori_shape,
                    transforms=transforms.transforms,
                    is_slide=False,
                    stride=None,
                    crop_size=None)
            pred = paddle.squeeze(pred)  # squeeze removes the size-1 dimensions; pred should now have shape [h, w]
            pred = pred.numpy().astype('uint8')
            predict_list.append(pred)
            progbar_pred.update(i + 1)
##############################  end of inference  ########################

#############  3. Stitch the predicted tiles back together  ########################
    count_temp = 0
    tmp = np.ones([ori_image.shape[0], ori_image.shape[1]])  # canvas for the stitched mask
    for h in range(h_step):
        for w in range(w_step):
            tmp[
            h * 256:(h + 1) * 256,
            w * 256:(w + 1) * 256
            ] = predict_list[count_temp]
            count_temp += 1
        # right edge: keep only the last |w_rest| columns of the edge tile
        tmp[h * 256:(h + 1) * 256, w_rest:] = predict_list[count_temp][:, w_rest:]
        count_temp += 1
    # bottom edge: keep only the last |h_rest| rows of each bottom tile
    for w in range(w_step):
        tmp[h_rest:, (w * 256):(w * 256 + 256)] = predict_list[count_temp][h_rest:, :]
        count_temp += 1
    # bottom-right corner
    tmp[-256:, -256:] = predict_list[count_temp][:, :]
##################  end of stitching  ########################

    # get the file name used for saving (strip the leading path)
    if image_dir is not None:
        # e.g. PaddleSeg/data/optic_disc_seg/JPEGImages/P0011.jpg -> /P0011.jpg
        im_file = img_lists[local_rank][local_rank].replace(image_dir, '')
    else:
        im_file = os.path.basename(img_lists[local_rank][local_rank])  # keeps the file extension
    if im_file[0] == '/':  # drop the leading '/'
        im_file = im_file[1:]

#############
    tmp = tmp.astype('uint8')
    # save added image
    added_image = utils.visualize.visualize(args.image_path,tmp, weight=0.6)
    added_image_path = os.path.join(added_saved_dir, im_file)
    mkdir(added_image_path)
    cv2.imwrite(added_image_path, added_image)
    
    # save pseudo color prediction
    pred_mask = utils.visualize.get_pseudo_color_map(tmp)
    pred_saved_path = os.path.join(pred_saved_dir,
                                    im_file.rsplit(".")[0] + ".png")
    mkdir(pred_saved_path)
    pred_mask.save(pred_saved_path)

    # pred_im = utils.visualize(im_path, pred, weight=0.0)
    # pred_saved_path = os.path.join(pred_saved_dir, im_file)
    # mkdir(pred_saved_path)
    # cv2.imwrite(pred_saved_path, pred_im)
 
    #progbar_pred.update(i + 1)

if __name__ == '__main__':
    args = parse_args()
    main(args)

The second version of predict.py wraps "crop the large image" and "stitch the tile predictions" into two functions, CropBigImage(ImagePath, CropScale) and PinJie(predict_list, CropScale, ori_image, h_step, w_step, h_rest, w_rest). CropBigImage can crop the large remote-sensing image into tiles of any size. The complete code of the second version is as follows:

import sys
import argparse
import os
import paddle
from paddleseg.cvlibs import manager, Config
from paddleseg.utils import get_sys_env, logger
import math
import cv2
import numpy as np
from paddleseg import utils
from paddleseg.core import infer
from paddleseg.utils import progbar

def mkdir(path):
    sub_dir = os.path.dirname(path)  # strip the file name, keep the parent directory
    if not os.path.exists(sub_dir):
        os.makedirs(sub_dir)

def partition_list(arr, m):
    """split the list 'arr' into m pieces"""
    n = int(math.ceil(len(arr) / float(m)))
    return [arr[i:i + n] for i in range(0, len(arr), n)]

def parse_args():
    parser = argparse.ArgumentParser(description='Model prediction')

    # params of prediction
    parser.add_argument(
        "--config", dest="cfg", help="The config file.", default=None, type=str)
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for prediction',
        type=str,
        default=None)
    parser.add_argument(
        '--image_path',
        dest='image_path',
        help=
        'The path of image, it can be a file or a directory including images',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the predicted results',
        type=str,
        default='./output/result')

    # augment for prediction
    parser.add_argument(
        '--aug_pred',
        dest='aug_pred',
        help='Whether to use multi-scale and flip augmentation for prediction',
        action='store_true')
    parser.add_argument(
        '--scales',
        dest='scales',
        nargs='+',
        help='Scales for augment',
        type=float,
        default=1.0)
    parser.add_argument(
        '--flip_horizontal',
        dest='flip_horizontal',
        help='Whether to use flip horizontally augment',
        action='store_true')
    parser.add_argument(
        '--flip_vertical',
        dest='flip_vertical',
        help='Whether to use flip vertically augment',
        action='store_true')

    # sliding window prediction
    parser.add_argument(
        '--is_slide',
        dest='is_slide',
        help='Whether to predict by sliding window',
        action='store_true')
    parser.add_argument(
        '--crop_size',
        dest='crop_size',
        nargs=2,
        help=
        'The crop size of sliding window, the first is width and the second is height.',
        type=int,
        default=None)
    parser.add_argument(
        '--stride',
        dest='stride',
        nargs=2,
        help=
        'The stride of sliding window, the first is width and the second is height.',
        type=int,
        default=None)

    return parser.parse_args()


def get_image_list(image_path):
    """Get image list"""
    valid_suffix = [
        '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png' ,'.tif'
    ]
    image_list = []
    image_dir = None
    if os.path.isfile(image_path):
        if os.path.splitext(image_path)[-1] in valid_suffix:
            image_list.append(image_path)
    elif os.path.isdir(image_path):
        image_dir = image_path
        for root, dirs, files in os.walk(image_path):  #root=image_path
            for f in files:
                if os.path.splitext(f)[-1] in valid_suffix:
                    image_list.append(os.path.join(root, f))
    else:
        raise FileNotFoundError(
            '`--image_path` is not found. it should be an image file or a directory including images'
        )

    if len(image_list) == 0:
        raise RuntimeError('There is no image file in `--image_path`')

    return image_list, image_dir   # return the list of images to predict and their directory

def CropBigImage(ImagePath, CropScale):

    seg_list = []  # list of cropped tiles
    ori_image = cv2.imread(ImagePath)
    h_step = ori_image.shape[0] // CropScale
    w_step = ori_image.shape[1] // CropScale

    h_rest = -(ori_image.shape[0] - CropScale * h_step)
    w_rest = -(ori_image.shape[1] - CropScale * w_step)

    # loop over the grid and crop
    for h in range(h_step):
        for w in range(w_step):
            # sliding-window sampling
            image_sample = ori_image[(h * CropScale):(h * CropScale + CropScale),
                           (w * CropScale):(w * CropScale + CropScale), :]
            seg_list.append(image_sample)
        # one extra tile per row for the right edge
        seg_list.append(ori_image[(h * CropScale):(h * CropScale + CropScale), -CropScale:, :])
    # bottom-edge tiles
    for w in range(w_step):
        seg_list.append(ori_image[-CropScale:, (w * CropScale):(w * CropScale + CropScale), :])
    # bottom-right corner tile
    seg_list.append(ori_image[-CropScale:, -CropScale:, :])

    return seg_list , ori_image , h_step , w_step , h_rest , w_rest

def PinJie(predict_list, CropScale, ori_image, h_step, w_step, h_rest, w_rest):

    # stitch the predicted tiles back into a full-size mask
    count_temp = 0
    tmp = np.ones([ori_image.shape[0], ori_image.shape[1]])
    for h in range(h_step):
        for w in range(w_step):
            tmp[
            h * CropScale:(h + 1) * CropScale,
            w * CropScale:(w + 1) * CropScale
            ] = predict_list[count_temp]
            count_temp += 1
        # right edge: keep only the last |w_rest| columns of the edge tile
        tmp[h * CropScale:(h + 1) * CropScale, w_rest:] = predict_list[count_temp][:, w_rest:]
        count_temp += 1
    # bottom edge: keep only the last |h_rest| rows of each bottom tile
    for w in range(w_step):
        tmp[h_rest:, (w * CropScale):(w * CropScale + CropScale)] = predict_list[count_temp][h_rest:, :]
        count_temp += 1
    # bottom-right corner
    tmp[-CropScale:, -CropScale:] = predict_list[count_temp][:, :]
    return tmp.astype('uint8')

def main(args):
    env_info = get_sys_env()
    place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
        'GPUs used'] else 'cpu'

    paddle.set_device(place)
    if not args.cfg:
        raise RuntimeError('No configuration file specified.')

    cfg = Config(args.cfg)
    val_dataset = cfg.val_dataset  # the val_dataset is used for its transforms
    if not val_dataset:
        raise RuntimeError(
            'The verification dataset is not specified in the configuration file.'
        )

    msg = '\n---------------Config Information---------------\n'
    msg += str(cfg)
    msg += '------------------------------------------------'
    logger.info(msg)

    model = cfg.model
    transforms = val_dataset.transforms
    #image_list, image_dir = get_image_list('data/UAV_seg/images')
    image_list, image_dir = get_image_list(args.image_path)  # args.image_path can be a directory of test images or a single image file
    
    model_path=args.model_path  # path of the trained model weights
    save_dir=args.save_dir
    aug_pred=args.aug_pred
    scales=args.scales
    flip_horizontal=args.flip_horizontal
    flip_vertical=args.flip_vertical
    is_slide=args.is_slide
    crop_size=args.crop_size
    stride=args.stride

    para_state_dict = paddle.load(model_path)
    model.set_dict(para_state_dict)
    model.eval()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    if nranks > 1:
        img_lists = partition_list(image_list, nranks)
    else:
        img_lists = [image_list]  # list of lists; img_lists[0] is the list of image paths

    added_saved_dir = os.path.join(save_dir, 'added_prediction')  # prediction blended with the original image
    pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')  # pseudo-color prediction
    
    # crop the large remote-sensing image into fixed-size tiles and build the tile list
    ImagePath = img_lists[local_rank][local_rank]
    CropScale = 256
    seg_list, ori_image, h_step, w_step, h_rest, w_rest = CropBigImage(ImagePath, CropScale)

    predict_list = []
    progbar_pred = progbar.Progbar(target=len(seg_list), verbose=1)
    logger.info("Start to predict...")
    with paddle.no_grad():
        for i, im in enumerate(seg_list):
            ori_shape = im.shape[:2]  # original tile shape (h, w)
            im, _ = transforms(im)    # im.shape: (3, CropScale, CropScale); the second return value is None
            im = im[np.newaxis, ...]   # im.shape: (1, 3, CropScale, CropScale)
            im = paddle.to_tensor(im)

            if aug_pred:
                pred = infer.aug_inference(
                    model,
                    im,
                    ori_shape=ori_shape,
                    transforms=transforms.transforms,
                    scales=scales,
                    flip_horizontal=flip_horizontal,
                    flip_vertical=flip_vertical,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            else:
                pred = infer.inference(
                    model,
                    im,
                    ori_shape=ori_shape,
                    transforms=transforms.transforms,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            pred = paddle.squeeze(pred)  # squeeze removes the size-1 dimensions; pred should now have shape [h, w]
            pred = pred.numpy().astype('uint8')
            predict_list.append(pred)
            progbar_pred.update(i + 1)

    # stitch the tile predictions back into a full-size mask
    tmp = PinJie(predict_list, CropScale, ori_image, h_step, w_step, h_rest, w_rest)
    #############
    # get the file name used for saving (strip the leading path)
    if image_dir is not None:
        # e.g. PaddleSeg/data/optic_disc_seg/JPEGImages/P0011.jpg -> /P0011.jpg
        im_file = ImagePath.replace(image_dir, '')
    else:
        im_file = os.path.basename(ImagePath)  # keeps the file extension
    if im_file[0] == '/':  # drop the leading '/'
        im_file = im_file[1:]

    # save added image
    added_image = utils.visualize.visualize(args.image_path,tmp, weight=0.6)
    added_image_path = os.path.join(added_saved_dir, im_file)
    mkdir(added_image_path)
    cv2.imwrite(added_image_path, added_image)
    
    # save pseudo color prediction
    pred_mask = utils.visualize.get_pseudo_color_map(tmp)
    pred_saved_path = os.path.join(pred_saved_dir,
                                    im_file.rsplit(".")[0] + ".png")
    mkdir(pred_saved_path)
    pred_mask.save(pred_saved_path)
    logger.info("-"*30+"END"+"-"*30)

if __name__ == '__main__':
    args = parse_args()
    main(args)
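
A quick way to sanity-check CropBigImage and PinJie in isolation is a round trip: crop a toy image, pretend each tile's first channel is its "prediction", and confirm that stitching reproduces that channel. The toy image, its size, and the file name below are made up for illustration; run this in the same file (or after importing the two functions):

import cv2
import numpy as np

# Toy image of 700 x 900 (deliberately not a multiple of 256 on either side).
toy = np.random.randint(0, 255, (700, 900, 3), dtype=np.uint8)
cv2.imwrite('toy.png', toy)

CropScale = 256
seg_list, ori_image, h_step, w_step, h_rest, w_rest = CropBigImage('toy.png', CropScale)

# Use the first channel of each tile as a fake "prediction"; stitching should
# then reproduce the first channel of the original image exactly.
fake_preds = [tile[:, :, 0] for tile in seg_list]
stitched = PinJie(fake_preds, CropScale, ori_image, h_step, w_step, h_rest, w_rest)
assert (stitched == ori_image[:, :, 0]).all()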

After rewriting predict.py, it needs to be tested. This post uses a single UAV remote-sensing image, with crop classification as the goal, shown below:
[Figure 1: UAV remote-sensing image used for testing]
To reduce edge effects and stitching artifacts, the original image was cropped with 50% overlap into a training set of more than 7,000 tiles of 256x256, a U-Net was trained on that dataset, and the rewritten predict.py was then run on the original image. A reference command for running predict.py: !python predict.py --config unet-uav.yml --model_path output/best_model/model.pdparams --image_path /home/aistudio/data/data70483/img.png, where --config, --model_path, and --image_path all need to be passed (the script accepts these arguments, but not only these). The prediction results are shown below:
[Figure 2: prediction result]
[Figure 3: prediction result]
Judging from the segmentation results, they look reasonably good. That said, the training data and the test image come from the same dataset here, so this says nothing about the model's ability to generalize, only that its ability to fit is acceptable. Still, the goal of this post is achieved: predicting on a large remote-sensing image.
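
For reference, the 50%-overlap cropping used above to build the training tiles can be sketched with a short helper like the one below. This is an illustrative, hypothetical helper rather than code from this post; it steps a 256-pixel window by 128 pixels, ignores the right/bottom remainders for brevity, and the same cropping would also be applied to the label masks:

import os
import cv2

def crop_with_overlap(image_path, out_dir, tile=256, overlap=0.5):
    # Crop image_path into tile x tile patches with the given overlap ratio.
    os.makedirs(out_dir, exist_ok=True)
    img = cv2.imread(image_path)
    stride = int(tile * (1 - overlap))   # 128-pixel step for 50% overlap
    h, w = img.shape[:2]
    idx = 0
    for top in range(0, h - tile + 1, stride):
        for left in range(0, w - tile + 1, stride):
            patch = img[top:top + tile, left:left + tile]
            cv2.imwrite(os.path.join(out_dir, '%05d.png' % idx), patch)
            idx += 1
    return idx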

If you have any questions, feel free to leave a comment. The next post will again use the PaddleSeg toolkit to predict on remote-sensing imagery, but with a method that differs from the code above: the large image to be predicted is cropped with overlap, and the overlapping borders of neighboring tiles are ignored when stitching, in order to reduce edge effects and stitching artifacts, which matters a great deal for semantic segmentation.
