在制作训练数据集,或使用训练好的模型对大尺寸图像进行预测时,需要将图像切割成 patches(图像块)。
patches 的切割可以分为:
- 离线切割,将 切割的 patches 保存至本地
- 在线切割, 使用滑动窗口的方式取 patches
制作训练数据切片和对大图像进行预测时,切割方式略有区别:
- 训练数据切片时,不需要对图像进行 padding(填充),不需要将切割后的 patches 拼接复原
- 对大图像进行预测时,如果采用腐蚀膨胀预测(每次只保留 patch 预测结果的中心部分,使保留下来的结果互不重叠),则需要先对大图像边缘进行 padding(零值填充),预测后再将各 patch 的结果拼接复原
将大图像按一定步长切割成 patches,
剔除目标像素较少以及只有背景的 patches
Note: 这里不考虑目标在 patches 中的位置;而在目标检测等任务中,可能需要让目标 bbox 落在 patch 的中心。
切割思路
原图像为 ( H, W, C )
从原图像左上角开始切割,每个 patches 的左上角索引为 (x, y), (x+step, y), ··· ···
知道了每个 patch 四个角的索引后,就可以通过索引切片从大图像中取出该 patch。
如果 patch 的右下角索引超出了图像边缘(即超过 W 或 H),则将该 patch 整体向左(或向上)平移,使其右下角索引恰好等于边缘值。
具体代码如下
## 以 步长为 30 将图像和标签掩膜切成 384*384 大小的切片
## 并剔除裂缝标注像素很少,以及没有裂缝标注的切片
import cv2
from skimage import io
import numpy as np
# 图像和标签读取函数
def read_data(file, is_label=False):
    """Read an image or a label mask from disk with OpenCV.

    Args:
        file: Path of the file to read.
        is_label: When True, decode the file as a single-channel grayscale
            mask. When False, decode it as a 3-channel colour image and
            convert it from OpenCV's BGR channel order to RGB.

    Returns:
        The decoded array (RGB channel order for colour images).

    Raises:
        FileNotFoundError: If OpenCV could not read or decode ``file``.
    """
    # cv2.IMREAD_COLOR is the named constant for the flag value 1.
    flag = cv2.IMREAD_GRAYSCALE if is_label else cv2.IMREAD_COLOR
    img = cv2.imread(file, flag)
    # cv2.imread reports failure by returning None rather than raising, so
    # check here to fail with a message that names the offending path.
    if img is None:
        raise FileNotFoundError("cannot read image file: {!r}".format(file))
    if not is_label:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
def read_by_gdal(file):
    """Read a raster file with GDAL and return its pixels as an array.

    Args:
        file: Path of the raster file to open.

    Returns:
        The array produced by ``ReadAsArray()``. A 3-D result has its first
        axis moved to the end (presumably band-first -> ``(H, W, C)``;
        confirm against the GDAL documentation). A 2-D result is returned
        unchanged.

    Raises:
        OSError: If GDAL could not open ``file``.
    """
    # Imported locally so the function also works when this file is imported
    # as a module: the only other import of `gdal` in the file sits under the
    # `__main__` guard.
    from osgeo import gdal

    dataset = gdal.Open(file)
    # gdal.Open may return None instead of raising when the open fails.
    if dataset is None:
        raise OSError("GDAL cannot open raster file: {!r}".format(file))

    im_array = dataset.ReadAsArray()
    if im_array.ndim == 3:
        im_array = im_array.transpose((1, 2, 0))
    return im_array
def save_data(save_path, img):
    """Cast ``img`` to unsigned 8-bit and write it to ``save_path``."""
    io.imsave(save_path, img.astype(np.uint8))
# 裁成 384*384 的切片, 裁剪步长 30
# 需要先对图片进行 padding
# 按上面的切割思路(包含末尾贴边补切的一块),每个方向上的切片数量 = ceil((H - h) / step) + 1,其中 H 为待切割大图像在该方向上的尺寸,h 为切片在该方向上的尺寸
def _patch_starts(length, patch_len, step):
    """Return the start offsets of every patch along one image axis."""
    last = length - patch_len
    # `last + 1` makes the stop inclusive, so a patch that ends exactly on
    # the border is generated by the regular stride.
    starts = np.arange(0, last + 1, step)
    # Otherwise add one extra patch shifted back so it ends on the border.
    # np.append returns a new array, so the result must be re-assigned.
    if starts[-1] != last:
        starts = np.append(starts, last)
    return starts


def image2patches(image, patch_size, step):
    """Lazily cut a large image into fixed-size, possibly overlapping patches.

    Patches are produced row by row (y outer loop, x inner loop), starting at
    the top-left corner and advancing by ``step``. When the stride does not
    land exactly on the bottom/right border, one extra patch aligned with
    that border is produced, so the whole image is covered without padding.

    Args:
        image: Array of shape ``(H, W)`` or ``(H, W, C)``.
        patch_size: ``(patch_height, patch_width)``, e.g. ``(256, 256)``.
            Must not exceed the image size.
        step: Stride between neighbouring patches, e.g. ``128``. Must be > 0.

    Yields:
        One float64 copy of each patch, of shape ``patch_size`` (plus the
        channel axis for 3-D input). Patches are yielded one at a time so
        that the full stack of patches never has to fit in memory.

    Raises:
        ValueError: If ``image`` is not 2-D/3-D, ``step`` is not positive, or
            ``patch_size`` is non-positive or larger than the image.
    """
    if image.ndim not in (2, 3):
        raise ValueError("image must have shape (H, W) or (H, W, C)")
    if step <= 0:
        raise ValueError("step must be greater than 0")

    im_high, im_width = image.shape[:2]
    patch_high, patch_width = patch_size[0], patch_size[1]
    if patch_high <= 0 or patch_width <= 0:
        raise ValueError("patch_size must be positive")
    if patch_high > im_high or patch_width > im_width:
        raise ValueError("patch_size must not be larger than the image")

    range_y = _patch_starts(im_high, patch_high, step)
    range_x = _patch_starts(im_width, patch_width, step)

    for y in range_y:
        for x in range_x:
            patch = image[y:y + patch_high, x:x + patch_width]
            # astype() copies, so consumers never alias the source image.
            yield patch.astype(np.float64)
# 剔除目标像素占比较少和只有背景的 patches
# 根据 标签 patch 计算目标像素
# 只针对类别数量为 两类 的情况
def remove_usless_patch(img_patch, lbl_patch, img_save_path, lbl_save_path,
                        min_target_pixels=30):
    """Save an image/label patch pair only if it has enough target pixels.

    Intended for two-class label masks. Patches that contain a single label
    value, or too few pixels of the second label value, are skipped.

    Args:
        img_patch: Image patch.
        lbl_patch: Label patch corresponding to ``img_patch``.
        img_save_path: Destination path for the image patch.
        lbl_save_path: Destination path for the label patch.
        min_target_pixels: The pair is saved only when the target pixel
            count is strictly greater than this value. Defaults to 30.

    Returns:
        True if the pair was saved, False if it was skipped.

    Note:
        Errors raised while saving propagate to the caller instead of being
        printed and ignored, so a failed write cannot go unnoticed.
    """
    # np.unique returns the label values sorted ascending with their counts,
    # so counts[1] belongs to the larger of the two label values.
    _, counts = np.unique(lbl_patch, return_counts=True)

    # A single label value means there is no second class to count.
    if len(counts) < 2:
        return False
    if counts[1] <= min_target_pixels:
        return False

    save_data(img_save_path, img_patch)
    save_data(lbl_save_path, lbl_patch)
    return True
## 对 patch 使用模型预测后,需要进行拼接。
## 在对大图像进行切割时,由于内存原因,需要使用到生成器
def patches2image(patches, imsize, step):
    """Stitch patches back into a full-size image, averaging the overlaps.

    The patch positions are derived from ``imsize``, the patch size and
    ``step``: starts advance by ``step`` from 0 along each axis, with one
    extra border-aligned start when the stride does not land exactly on the
    border. Patches must arrive in row-major order (y outer, x inner).

    Args:
        patches: Array of shape ``(N, h, w)`` or ``(N, h, w, C)``, or any
            iterable (e.g. a generator) yielding ``(h, w)`` / ``(h, w, C)``
            arrays. An iterable is consumed one patch at a time.
        imsize: ``(height, width)`` of the image to rebuild, e.g.
            ``(771, 841)``.
        step: Stride that was used when the patches were cut.

    Returns:
        A float array of shape ``imsize`` (plus the channel axis when the
        patches have one). Each pixel is the mean of all patches covering
        it; pixels covered by no patch stay 0.

    Raises:
        ValueError: If ``step`` is not positive, ``patches`` is empty, a
            patch is larger than ``imsize``, or the number of patches does
            not match the number of positions.
    """
    if step <= 0:
        raise ValueError("step must be greater than 0")

    def axis_starts(length, patch_len):
        # Start offsets along one axis, including the border-aligned one.
        last = length - patch_len
        starts = np.arange(0, last + 1, step)
        if starts[-1] != last:
            starts = np.append(starts, last)
        return starts

    res = None
    weight = None
    positions = []
    patch_h = patch_w = 0
    count = 0
    for count, patch in enumerate(patches, start=1):
        if res is None:
            # The first patch fixes the patch size and the channel layout.
            if patch.ndim not in (2, 3):
                raise ValueError("each patch must be (h, w) or (h, w, C)")
            patch_h, patch_w = patch.shape[:2]
            if patch_h > imsize[0] or patch_w > imsize[1]:
                raise ValueError("patches must not be larger than imsize")
            positions = [
                (y, x)
                for y in axis_starts(imsize[0], patch_h)
                for x in axis_starts(imsize[1], patch_w)
            ]
            out_shape = (imsize[0], imsize[1]) + tuple(patch.shape[2:])
            res = np.zeros(out_shape)
            # Counts how many patches contributed to every pixel.
            weight = np.zeros(out_shape)
        if count > len(positions):
            raise ValueError("more patches than positions for this imsize/step")
        y, x = positions[count - 1]
        res[y:y + patch_h, x:x + patch_w] += patch
        weight[y:y + patch_h, x:x + patch_w] += 1

    if res is None:
        raise ValueError("patches is empty")
    if count != len(positions):
        raise ValueError("fewer patches than positions for this imsize/step")

    # Average overlapping regions. Uncovered pixels (possible when step is
    # larger than the patch size) keep their zero value.
    covered = weight > 0
    res[covered] /= weight[covered]
    return res
if __name__ == "__main__":
    import os
    # Binds `gdal` at module level, where read_by_gdal can resolve it.
    from osgeo import gdal

    im_path = r"E:\chenximing\Cracks\GIS\crack_001.tif"
    label_path = r"E:\chenximing\Cracks\GoafCrack\Dataset_Production\Raw_data\Arcgis_label_production_tmp\label_001_.tif"

    im = read_by_gdal(im_path)
    label = read_by_gdal(label_path)
    print(im.shape, im.dtype)
    print(label.shape, label.dtype)

    # Both generators use identical parameters, so the n-th image patch and
    # the n-th label patch cover the same region.
    im2patch = image2patches(image=im, patch_size=(384, 384), step=40)
    label2patch = image2patches(image=label, patch_size=(384, 384), step=40)

    img_save_dir = r'./Raw_data/image'
    label_save_dir = r'./Raw_data/label'
    # Create the output folders up front; otherwise every save would fail.
    os.makedirs(img_save_dir, exist_ok=True)
    os.makedirs(label_save_dir, exist_ok=True)

    patch_base_name = 'shenyuan'
    img_suffix = '.png'
    label_suffix = '.bmp'

    for n, (im_patch, label_patch) in enumerate(zip(im2patch, label2patch)):
        print(n, im_patch.shape, np.unique(im_patch))
        print(n, label_patch.shape, np.unique(label_patch))

        # Build the output paths: <base>_<index><suffix>.
        patch_number_ = str(n)
        img_patch_name = patch_base_name + "_" + patch_number_ + img_suffix
        label_patch_name = patch_base_name + "_" + patch_number_ + label_suffix
        img_save_path = os.path.join(img_save_dir, img_patch_name)
        label_save_path = os.path.join(label_save_dir, label_patch_name)
        print(img_save_path, label_save_path)

        # Saves the pair only when the label patch has enough target pixels.
        remove_usless_patch(im_patch, label_patch, img_save_path, label_save_path)
https://blog.csdn.net/qq_15054345/article/details/119902974
https://www.cnblogs.com/walter-xh/p/11755550.html
https://blog.csdn.net/zengwubbb/article/details/115800477
tianchi_CountyAgriculturalBrain_top1/inference.py at 22ae5fe13ca5ec14bfbcef166313f130a29bc272 · HuangQinJian/tianchi_CountyAgriculturalBrain_top1 · GitHub
将图像切分为图像块,并复原 - 知乎 (zhihu.com)