深度学习大图切成小块图片代码---针对图像分割而言

1 根据图片大小和patch大小生成切图位置:

import glob
import os.path

import numpy as np
from skimage import io as sio
from params import * # 这个是我自己的一个超参数文件
from skimage import exposure


# Rule of thumb: stride = patch_size[0] // 2
def gen_patch_pos(org_image, patch_size: list = [256, 256], stride: int = 128):
    """Compute the top-left offsets of the sliding-window patches.

    Along each axis the offsets come from ``np.arange(0, extent - patch,
    stride)``. The last valid start (``extent - patch``) is appended when it
    is not already the final offset, so the last patch ends flush with the
    image border. If that range is empty (the image is not larger than the
    patch along the axis), the single offset 0 is used instead.

    Args:
        org_image: array whose first two ``shape`` entries are (height, width).
        patch_size: ``[patch_height, patch_width]``.
        stride: step between consecutive offsets.

    Returns:
        ``(range_y, range_x)``: 1-D arrays of row offsets and column offsets.
    """

    def axis_offsets(extent, window):
        # Largest start position that still keeps the patch inside the image.
        last_start = extent - window
        offsets = np.arange(0, last_start, stride)
        if offsets.size == 0:
            # Patch is at least as large as the image on this axis.
            return np.append(offsets, 0)
        if offsets[-1] == last_start:
            return offsets
        return np.append(offsets, last_start)

    shape = org_image.shape
    rows = axis_offsets(shape[0], patch_size[0])
    cols = axis_offsets(shape[1], patch_size[1])
    return rows, cols

2  根据上一步生成的位置开始切图:

# Cut the image into patches at the given row/column offsets.
def patch_image_by_pos(image, range_y: np.ndarray,
                       range_x: np.ndarray, patch_size: list = [256, 256]):
    """Extract one patch per (y, x) offset pair and stack them.

    Patches are produced in row-major order: for every ``y`` in ``range_y``,
    every ``x`` in ``range_x``. A patch that reaches past the image border is
    zero-padded on its bottom/right side.

    Args:
        image: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` array.
        range_y: row offsets of the patches' top edges.
        range_x: column offsets of the patches' left edges.
        patch_size: ``[patch_height, patch_width]``.

    Returns:
        A float64 array of shape ``(N, patch_height, patch_width)`` for 2-D
        input or ``(N, patch_height, patch_width, C)`` for 3-D input, where
        ``N = len(range_y) * len(range_x)``.

    Raises:
        ValueError: if ``image`` is neither 2-D nor 3-D.
    """
    if len(image.shape) not in (2, 3):
        raise ValueError(
            "image must be 2-D (H, W) or 3-D (H, W, C), got shape %r"
            % (tuple(image.shape),)
        )

    patch_h, patch_w = patch_size[0], patch_size[1]
    count = len(range_y) * len(range_x)

    # Take the channel count from the image itself instead of assuming 3, so
    # single-channel (H, W, 1) and RGBA (H, W, 4) inputs keep their channels.
    res = np.zeros((count, patch_h, patch_w) + tuple(image.shape[2:]))

    index = 0
    for y in range_y:
        for x in range_x:
            tile = image[y:y + patch_h, x:x + patch_w]
            # Slicing clips at the border, so the tile may be smaller than
            # the patch; the uncovered remainder of res[index] stays zero.
            res[index, :tile.shape[0], :tile.shape[1]] = tile
            index += 1
    return res

上面两个代码参考了:https://zhuanlan.zhihu.com/p/39361808

3. 如果需要腐蚀膨胀预测,需要对图片的边缘进行零值填充,思路如下:

# predict膨胀预测:预测时,对每个(256,256)窗口,每次只保留中心(128,128)区域预测结果,每次滑窗步长为128,使预测结果不交叠;当然这个窗口大小和滑窗步长可以自定义,但是需注意的是步长需要为窗口大小的1/2

膨胀预测代码,参考自:

https://github.com/HuangQinJian/tianchi_CountyAgriculturalBrain_top1/blob/22ae5fe13ca5ec14bfbcef166313f130a29bc272/src/tools/inference.py#L32
# predict膨胀预测:预测时,对每个(256,256)窗口,每次只保留中心(128,128)区域的预测结果,
# 因此图像四周需要各补上256/4=64个像素;每次滑窗步长为128,使保留下来的预测结果互不交叠。思路如下:
'''
        0. 填充图像四个边界,使原图大小可以被滑动窗口整除;
        1. 再创造一个空白图像,将滑窗预测结果逐步填充至空白图像中;
        2. 膨胀预测:预测时,对每个(patch_size[0],patch_size[1])窗口,
           每次只保留中心(patch_size[0]/2,patch_size[1]/2)区域预测结果,
        每次滑窗步长为patch_size[1]/2,使预测结果不交叠;
'''

# Zero-pad the image onto a canvas whose size tiles evenly into patches.
def expand_image(image: np.ndarray, patch_size: list = [256, 256],
                    stride: int = 128):
    """Paste ``image`` into a zero canvas, leaving a border on every side.

    The canvas comes from ``gen_zeros_image_expanding``. The image is copied
    in at an offset of a quarter patch (``patch_size[i] // 4``) from the top
    and left edges.

    Args:
        image: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` array.
        patch_size: ``[patch_height, patch_width]``.
        stride: forwarded unchanged to ``gen_zeros_image_expanding``.

    Returns:
        The canvas with ``image`` written into it.

    Raises:
        ValueError: if ``image`` is neither 2-D nor 3-D.
    """
    if len(image.shape) not in (2, 3):
        raise ValueError(
            "image must be 2-D (H, W) or 3-D (H, W, C), got shape %r"
            % (tuple(image.shape),)
        )
    im_h, im_w = image.shape[:2]

    new_image = gen_zeros_image_expanding(image, patch_size, stride)

    # Quarter-patch margin on the top/left; the same margin is used when the
    # canvas size is computed.
    field_h, field_w = patch_size[0] // 4, patch_size[1] // 4
    new_image[field_h:im_h + field_h, field_w:im_w + field_w] = image
    return new_image

# Build the all-zero canvas that the expanded image is pasted into.
def gen_zeros_image_expanding(image, patch_size: list,
        stride: int):
    """Allocate a zero ``uint8`` canvas sized for the expanded image.

    Height and width come from ``gen_new_img_shape``. For 3-D input the
    channel count is taken from ``image`` rather than fixed at 3.

    Args:
        image: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` array; only its shape is
            read.
        patch_size: ``[patch_height, patch_width]``.
        stride: unused here; kept so existing callers keep working.

    Returns:
        Zero array of shape ``(new_h, new_w)`` or ``(new_h, new_w, C)``.

    Raises:
        ValueError: if ``image`` is neither 2-D nor 3-D.
    """
    if len(image.shape) not in (2, 3):
        raise ValueError(
            "image must be 2-D (H, W) or 3-D (H, W, C), got shape %r"
            % (tuple(image.shape),)
        )
    im_h, im_w = image.shape[:2]

    new_shape_h, new_shape_w = gen_new_img_shape(im_h, im_w, patch_size)

    # Append the input's own channel axis (if any) so that single-channel and
    # RGBA images can be pasted into the canvas without a shape mismatch.
    canvas_shape = (new_shape_h, new_shape_w) + tuple(image.shape[2:])

    # NOTE(review): dtype stays uint8 as before, so a non-uint8 image is cast
    # when it is assigned into this canvas -- confirm that is intended.
    return np.zeros(canvas_shape, dtype=np.uint8)


# Compute the (height, width) of the expanded image.
def gen_new_img_shape(im_h: int, im_w: int, patch_size: list):
    """Return the expanded canvas size for an ``im_h`` x ``im_w`` image.

    Each dimension is first rounded up to the next multiple of the patch
    size, then a margin of a quarter patch (``patch_size[i] // 4``) is added
    on both sides.

    Args:
        im_h: original image height.
        im_w: original image width.
        patch_size: ``[patch_height, patch_width]``.

    Returns:
        ``(new_shape_h, new_shape_w)`` as integers.
    """
    patch_h, patch_w = patch_size[0], patch_size[1]

    # Integer ceiling division: -(-a // b) == ceil(a / b). Staying in integer
    # arithmetic avoids the float detour of true division followed by int().
    tiled_h = -(-im_h // patch_h) * patch_h
    tiled_w = -(-im_w // patch_w) * patch_w

    field_h, field_w = patch_h // 4, patch_w // 4
    # Margin above and below.
    new_shape_h = field_h + tiled_h + field_h
    # Margin left and right.
    new_shape_w = field_w + tiled_w + field_w
    return new_shape_h, new_shape_w

# Ceiling division: the smallest integer n with n * fen_mu >= fen_zi.
def find_near_divisible(fen_zi: int, fen_mu: int):
    """Return ``ceil(fen_zi / fen_mu)`` as an integer.

    Args:
        fen_zi: numerator (e.g. an image dimension).
        fen_mu: denominator (e.g. a patch dimension).

    Returns:
        The number of ``fen_mu``-sized blocks needed to cover ``fen_zi``.
        The result is always an ``int``, including when ``fen_zi`` is an
        exact multiple of ``fen_mu``.
    """
    # Floor-dividing the negated numerator and negating again rounds up
    # without going through floating point.
    return -(-fen_zi // fen_mu)

膨胀预测之后,需要从膨胀后的整图预测结果中裁剪出与原图尺寸一致的分割预测结果(下面的片段省略了函数头;从变量用法看,res 应为拼接后的整图预测结果,im_size 应为原图尺寸)。参考自:

https://zhuanlan.zhihu.com/p/39361808 以及 https://github.com/HuangQinJian/tianchi_CountyAgriculturalBrain_top1/blob/22ae5fe13ca5ec14bfbcef166313f130a29bc272/src/tools/inference.py#L32
    field_h, field_w = patch_size[0] // 4, patch_size[1] // 4
    res = res[field_h:-field_h, field_w:-field_w]  # 去除整体外边界
    org_h, org_w = im_size[0], im_size[1]
    res = res[: org_h, : org_w]  # 去除补全patch_size[0]和patch_size[1]整数倍时的右下边界
    return res

你可能感兴趣的:(pytorch,计算机视觉与深度学习,深度学习)