1. 根据图片大小和 patch 大小生成切图位置:
import glob
import os.path
import numpy as np
from skimage import io as sio
from params import * # 这个是我自己的一个超参数文件
from skimage import exposure
# 一般来说,stride = patch_size[0]//2
def gen_patch_pos(org_image, patch_size: list = [256, 256], stride: int = 128):
    """Compute the top-left coordinates of the patches that tile ``org_image``.

    Along each axis the candidate starts are ``0, stride, 2*stride, ...``
    (strictly below ``size - patch``). The position ``size - patch`` is then
    appended so the last patch ends flush with the image border. If the patch
    is at least as large as the image along an axis, the only start is 0.

    Args:
        org_image: array whose ``shape[0]`` is the height and ``shape[1]``
            the width.
        patch_size: ``[patch_h, patch_w]``.
        stride: step between consecutive starts (usually ``patch_size[0] // 2``).

    Returns:
        ``(range_y, range_x)``: 1-D arrays of row starts and column starts.
    """
    def _axis_starts(extent, window):
        last = extent - window
        starts = np.arange(0, last, stride)
        if starts.size == 0:
            # The patch is not smaller than the image on this axis:
            # use a single patch anchored at 0.
            return np.append(starts, 0)
        if starts[-1] != last:
            # Add one more patch that ends exactly on the far border.
            return np.append(starts, last)
        return starts

    shape = org_image.shape
    return _axis_starts(shape[0], patch_size[0]), _axis_starts(shape[1], patch_size[1])
2. 根据上一步生成的位置开始切图:
# 根据位置数组将图片划分成子块
def patch_image_by_pos(image, range_y: np.ndarray,
                       range_x: np.ndarray, patch_size: list = [256, 256]):
    """Cut ``image`` into fixed-size patches at the given top-left positions.

    Patches are produced in row-major order of the positions: ``range_y`` is
    the outer loop and ``range_x`` the inner loop. A patch that extends past
    the bottom/right edge of ``image`` is zero-padded on the bottom/right.

    Args:
        image: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` array. Any channel count
            ``C`` is accepted.
        range_y: top coordinates of the patches (e.g. from ``gen_patch_pos``).
        range_x: left coordinates of the patches.
        patch_size: ``[patch_h, patch_w]``.

    Returns:
        float64 array of shape ``(len(range_y) * len(range_x), patch_h, patch_w)``
        for 2-D input, or ``(..., patch_h, patch_w, C)`` for 3-D input.

    Raises:
        ValueError: if ``patch_size`` is None or ``image`` is not 2-D/3-D.
    """
    if patch_size is None:
        raise ValueError("patch_size must not be None")
    if len(image.shape) not in (2, 3):
        raise ValueError(
            f"expected a 2-D or 3-D image, got shape {tuple(image.shape)}")
    patch_h, patch_w = patch_size[0], patch_size[1]
    # Trailing channel axis (empty tuple for 2-D images). It is taken from the
    # image rather than hard-coded to 3, so 1- and 4-channel inputs work too.
    channel_shape = tuple(image.shape[2:])
    n_patches = len(range_y) * len(range_x)
    # np.zeros pre-fills with zeros, so the unfilled part of an edge patch
    # is already the required zero padding.
    res = np.zeros((n_patches, patch_h, patch_w) + channel_shape)
    index = 0
    for y in range_y:
        for x in range_x:
            patch = image[y:y + patch_h, x:x + patch_w]
            # The slice may be smaller than patch_size at the image border.
            res[index, :patch.shape[0], :patch.shape[1]] = patch
            index += 1
    return res
上面两段代码参考了: https://zhuanlan.zhihu.com/p/39361808
3. 如果需要膨胀预测(滑窗预测时只保留每个窗口中心区域的结果),需要先对图片的边缘进行零值填充,思路如下:
# predict膨胀预测:预测时,对每个(256,256)窗口,每次只保留中心(128,128)区域预测结果,每次滑窗步长为128,使预测结果不交叠;当然这个窗口大小和滑窗步长可以自定义,但是需注意的是步长需要为窗口大小的1/2
膨胀预测代码,参考自:
https://github.com/HuangQinJian/tianchi_CountyAgriculturalBrain_top1/blob/22ae5fe13ca5ec14bfbcef166313f130a29bc272/src/tools/inference.py#L32
# predict膨胀预测:预测时,对每个(256,256)窗口,每次只保留中心(128,128)区域预测结果,
# 因此四周需要各补上 256/4=64 个像素,每次滑窗步长为128,使预测结果不交叠,思路如下:
'''
0. 填充图像四个边界,使原图大小可以被滑动窗口整除;
1. 再创造一个空白图像,将滑窗预测结果逐步填充至空白图像中;
2. 膨胀预测:预测时,对每个(patch_size[0],patch_size[1])窗口,
每次只保留中心(patch_size[0]/2,patch_size[1]/2)区域预测结果,
每次滑窗步长为patch_size[1]/2,使预测结果不交叠;
'''
# Zero-pad `image`: height/width are rounded up to multiples of patch_size, plus a patch_size // 4 border on every side.
def expand_image(image: np.ndarray, patch_size: [] = [256,256],
                 stride: int = 128):
    """Paste ``image`` into a zero canvas for dilated sliding-window prediction.

    The canvas comes from ``gen_zeros_image_expanding``. The image is placed
    with its top-left corner offset by
    ``(patch_size[0] // 4, patch_size[1] // 4)``, and everything else on the
    canvas stays zero.

    Args:
        image: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` array.
        patch_size: ``[patch_h, patch_w]`` of the prediction window.
        stride: sliding-window stride, forwarded to
            ``gen_zeros_image_expanding``.

    Returns:
        The padded canvas with ``image`` copied in.
    """
    # Tuple unpacking (rather than slicing shape[:2]) keeps the ValueError
    # for shapes that are neither (H, W) nor (H, W, C).
    if len(image.shape) != 2:
        rows, cols, _ = image.shape
    else:
        rows, cols = image.shape
    canvas = gen_zeros_image_expanding(image, patch_size, stride)
    assert canvas is not None
    top, left = patch_size[0] // 4, patch_size[1] // 4
    canvas[top:top + rows, left:left + cols] = image
    return canvas
# 生成扩张后的零值图像
def gen_zeros_image_expanding(image, patch_size: list,
                              stride: int):
    """Allocate the zero canvas that ``image`` is pasted into for dilated prediction.

    The height/width come from ``gen_new_img_shape``. Each side of ``image``
    is rounded up to a multiple of the matching ``patch_size`` entry, plus a
    ``patch_size[i] // 4`` border on both ends.

    Args:
        image: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` array.
        patch_size: ``[patch_h, patch_w]``.
        stride: unused here; accepted so callers can pass the sliding-window
            parameters through unchanged.

    Returns:
        A zero array of shape ``(new_h, new_w)`` or ``(new_h, new_w, C)`` with
        the same dtype as ``image``.

    Raises:
        ValueError: if ``image`` is neither 2-D nor 3-D.
    """
    if len(image.shape) not in (2, 3):
        raise ValueError(
            f"expected a 2-D or 3-D image, got shape {tuple(image.shape)}")
    im_h, im_w = image.shape[0], image.shape[1]
    new_shape_h, new_shape_w = gen_new_img_shape(im_h, im_w, patch_size)
    # Copy the channel axis and dtype from the input instead of hard-coding
    # 3 channels and uint8. The old canvas forced every 3-D image to 3
    # channels (4-channel images failed to broadcast into it). It also
    # silently truncated/wrapped float or 16-bit pixel values once
    # expand_image copied the image into it.
    channel_shape = tuple(image.shape[2:])
    return np.zeros((new_shape_h, new_shape_w) + channel_shape,
                    dtype=image.dtype)
# 找到扩张后的图像的尺寸
def gen_new_img_shape(im_h: int, im_w: int, patch_size: []):
    """Return ``(new_h, new_w)``, the padded canvas size for dilated prediction.

    Each side is rounded up to a whole number of patches (via
    ``find_near_divisible``) and then widened by a ``patch // 4`` border on
    both ends. For example, a 300x512 image with 256x256 patches gives
    ``(64 + 512 + 64, 64 + 512 + 64) == (640, 640)``.
    """
    sides = []
    for length, patch in ((im_h, patch_size[0]), (im_w, patch_size[1])):
        border = patch // 4
        body = int(find_near_divisible(length, patch)) * patch
        sides.append(border + body + border)
    new_h, new_w = sides
    return new_h, new_w
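# Return ceil(fen_zi / fen_mu): the smallest integer n with n * fen_mu >= fen_zi (fen_zi = numerator, fen_mu = denominator).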
def find_near_divisible(fen_zi: int, fen_mu: int):
    """Return ``ceil(fen_zi / fen_mu)``.

    ``fen_zi`` is the numerator and ``fen_mu`` the denominator. The result is
    the smallest integer ``n`` with ``n * fen_mu >= fen_zi``, i.e. the number
    of ``fen_mu``-sized blocks needed to cover ``fen_zi``.

    Examples: ``(512, 256) -> 2``, ``(513, 256) -> 3``, ``(0, 256) -> 0``.

    Raises:
        ZeroDivisionError: if ``fen_mu`` is 0.
    """
    # Ceiling division via floor division of the negated numerator. The old
    # exact-division branch used true division, so it returned a float (2.0)
    # while the other branch returned an int. It could also lose precision
    # for very large ints. This form stays an exact int for int inputs.
    return -(-fen_zi // fen_mu)
膨胀预测之后,从膨胀图像中取出和原图一样的分割预测结果,参考自:
https://zhuanlan.zhihu.com/p/39361808 以及 https://github.com/HuangQinJian/tianchi_CountyAgriculturalBrain_top1/blob/22ae5fe13ca5ec14bfbcef166313f130a29bc272/src/tools/inference.py#L32
# Crop the dilated-prediction result on the expanded canvas back to the
# original image size.
# NOTE(review): this is a fragment of a larger function whose header is not
# shown; `res`, `patch_size` and `im_size` must be defined there. `im_size`
# presumably holds the original image's (height, width) - confirm with caller.
# Width of the zero border that was added on each side of the canvas.
field_h, field_w = patch_size[0] // 4, patch_size[1] // 4
# NOTE(review): if a patch_size entry is < 4 the border is 0 and the `-0`
# stop below produces an empty slice; fine for the default 256x256 patches.
res = res[field_h:-field_h, field_w:-field_w] # Remove the outer border added on every side
org_h, org_w = im_size[0], im_size[1]
res = res[: org_h, : org_w] # Remove the bottom/right padding added to round up to multiples of patch_size[0] / patch_size[1]
return res