语义分割 patches 训练数据制作

patches 切割

在制作训练数据集,或使用训练好的模型对大尺寸图像进行预测时,需要将图像进行切割成 patches

patches 的切割可以分为:

  • 离线切割,将 切割的 patches 保存至本地
  • 在线切割, 使用滑动窗口的方式取 patches

训练数据切片制作大图像测试时的切割有稍微的区别

  • 训练数据切片时,不需要对图像进行 padding(填充),不需要将切割后的 patches 拼接复原
  • 在对大图像进行预测时,如果需要进行 腐蚀膨胀预测 时,每次只保留 patches 的中心部分,使预测结果不相交,因此需要对大图像边缘进行 padding(零值填充),并 拼接复原

训练 patches 数据制作

  1. 将大图像按一定步长切割成 patches,

  2. 剔除目标像素较少以及只有背景的 patches

    Note: 不考虑目标在patches 中的位置,如目标检测 任务中时可能需要将目标 bbox 落在 patches 中心;

切割思路

语义分割 patches 训练数据制作_第1张图片

原图像为 ( H, W, C )
从原图像左上角开始切割,每个 patches 的左上角索引为 (x, y), (x+step, y), ··· ···
知道了每个 patches 的各个角的索引,通过索引 则可以对 大图像进行 索引切片 获得。
如果 patches 的右下索引 超过了 图像的 边缘,即 W,或H . 则将 切片右下角索引向左移动,使其等于边缘值,
具体代码如下

## 以 步长为 30 将图像和标签掩膜切成 384*384 大小的切片
## 并剔除裂缝标注像素很少,以及没有裂缝标注的切片

import cv2
from skimage import io
import numpy as np


# 图像和标签读取函数
def read_data(file, is_label = False):
    if is_label:
    	img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
    else:
    	img = cv2.imread(file,1)
       # print(img.shape)
    	img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
	return img

def read_by_gdal(file):
	dataset = gdal.Open(file)
    im_array = dataset.ReadAsArray()
    if len(im_array.shape) == 3:
    	im_array = im_array.transpose((1,2,0))
    return im_array

def save_data(save_path, img):
    img = img.astype(np.uint8)
    io.imsave(save_path, img)


# 裁成 384*384 的切片, 裁剪步长 30
# 需要先对图片进行 padding
# 切片数量 = ( H/(h-step)  # H:待切割大图像,切片宽

def image2patches(image,patch_size,step):
"""
image: 待切割的大图像
patch_size: patch 大小, 如(256, 256)
step: 切割步长, 如 128
"""
# patch_size 不能大于 image
# step 要大于 0 
    assert patch_size[0] < image.shape[0]
    assert patch_size[1] < image.shape[1] 
    assert step > 0

    if len(image.shape) == 2:
    	im_high, im_width = image.shape
    if len(image.shape) == 3:
   		im_high, im_width, im_channel = image.shape

    ## 构建图像块索引
    range_y = np.arange(0,im_high-patch_size[0], step) 
    range_x = np.arange(0, im_width - patch_size[1], step)# img_high - patch_size[0] 表示索引不会超过 img_high - patch_szie[0]位置
# 这样,判断最后一个索引是否等于 img_high - patch_szie[0]
# 如果最后一个索引刚好等于 img_high - patch_szie[0],
# 则可以知道这最后一个 patch 是否刚好能到边缘。
# 否则,需要再添加一个索引将剩下不够一个 patch_size 的部分切出来
    if range_y[-1] != im_high - patch_size[0]:
    	np.append(range_y,im_high - patch_size[0])
    if range_y[-1] != im_width - patch_size[1]:
    	np.append(range_x,im_high - patch_size[0])

    # 图像块的数量
    sz = len(range_x) * len(range_y)

    if len(image.shape) == 2:
       ## 初始化灰度图
    	res = np.zeros((sz, patch_size[0], patch_size[1]))
    if len(image.shape) == 3:
       ## 生成全 0 的 numpy array 存放原图的 patch
       ## 初始化 RGB 图像
    	res = np.zeros((sz,patch_size[0],patch_size[1],im_channel))

    index = 0
    for y in range_y:
    	for x in range_x:
        	patch = image[y:y+patch_size[0], x:x+patch_size[1]]
           # 当图像太大时,一次性将切片存储为大图像会导致内存不够,因此使用生成器,或取一张切片返回一张
        	res[index] = patch
        	yield res[index]
        	index += 1

    # print(res.shape)
    # return res


# 剔除目标像素占比较少和只有背景的 patches
# 根据 标签 patch 计算目标像素
# 只针对类别数量为 两类 的情况

def remove_usless_patch(img_patch, lbl_patch, img_save_path, lbl_save_path):
    """
    img_patch: iamge patch
    lal_patch: label_patch
    img_save_path, lbl_save_path: 图片和标签保存地址
    image patch 和 label_patch 是对应的  
    """

    # 计算 patch 中每类的像素个数
    # clss: array, 类别; counts: array, 计数 
    clss, counts = np.unique(lbl_patch, return_counts=True)
    # note: 如果全为一个类,则上面的 counts 只有一个元素
    try:
        target = counts[1]
        # 如果目标像素大于30个,则保存至本地
        if target > 30:
            save_data(img_save_path,img_patch)
            save_data(lbl_save_path,lbl_patch)
    except Exception as e:
        print(e)

    #总像素
    #overall = lbl_patch.size


## 对 patch 使用模型预测后,需要进行拼接。
## 在对大图像进行切割时,由于内存原因,需要使用到生成器

def patches2image(patches, imsize, step):
    """
    patches: 使用 image2patches 得到的数据
    imsize: 原始图像的 宽和高, 如(771,841)
    step: 生成 patch 时的步长
    """
    patch_size = patches.shape[1:3]
    if len(patches.shape) == 3:
        ## 灰度图像
        res = np.zeros((imsize[0],imsize[1]))
        w = np.zeros((imsize[0],imsize[1]))
    if len(patches.shape) == 4:
        ## rgb
        res = np.zeros((imsize[0],imsize[1],3))
        w = np.zeros((imsize[0],imsize[1],3))

    ## 与切片函数中相同,获得切片索引
    range_y = np.arange(0, imsize[0] - patch_size[0], step)
    range_x = np.arange(0, imsize[1] - patch_size[1], step)
    if range_y[-1] != imsize[0] - patch_size[0]:
        range_y = np.append(range_y, imsize[0] - patch_size[0])
    if range_y[-1] != imsize[1] - patch_size[1]:
        range_y = np.append(range_y, imsize[1] - patch_size[1])
    
    index = 0 
    for y in range_y:
        for x in range_x:
            res[y:y+patch_size[0], x:x+patch_size[1]] = res[y:]

if __name__ == "__main__":
import matplotlib.pyplot as plt
import os
from osgeo import gdal


im_path = r"E:\chenximing\Cracks\GIS\crack_001.tif"
label_path = r"E:\chenximing\Cracks\GoafCrack\Dataset_Production\Raw_data\Arcgis_label_production_tmp\label_001_.tif"

im = read_by_gdal(im_path)
label = read_by_gdal(label_path)
print(im.shape,im.dtype)
print(label.shape,label.dtype )
im2patch = image2patches(image=im,patch_size=(384,384),step=40)
label2patch = image2patches(image = label, patch_size=(384,384), step=40)

# print(im2patch[0].shape)

img_save_dir = r'./Raw_data/image'
label_save_dir = r'./Raw_data/label'
patch_base_name = 'shenyuan'

img_suffix = '.png'
label_suffix= '.bmp'

# if len(im.shape) == 3:
#     sufix = '.png'
# if len(im.shape) == 2:
#     sufix = '.bmp'

# for i in np.arange(im2patch.shape[0]):
#     print(i,im2patch[i].shape)

#     # 保存切片
#     patch_number_ = str(i)
#     patch_name = patch_base_name + "_" + patch_number_ + sufix

#     save_path = os.path.join(save_dir, patch_name)
#     print(save_path)
#     save_data(save_path,im2patch[i])


for n, patches in enumerate(zip(im2patch,label2patch)):
   print(n, patches[0].shape,np.unique(patches[0]))
   print(n, patches[1].shape,np.unique(patches[1]))
   # 保存切片
   patch_number_ = str(n)
   img_patch_name = patch_base_name + "_" + patch_number_ + img_suffix
   label_patch_name = patch_base_name + "_" + patch_number_ + label_suffix

   img_save_path = os.path.join(img_save_dir, img_patch_name)
   label_save_path = os.path.join(label_save_dir, label_patch_name)

   print(img_save_path,label_save_path)
   im_patch = patches[0]
   label_patch = patches[1]

   remove_usless_patch(im_patch, label_patch, img_save_path,label_save_path)








# for ii in np.arange(im2patch.shape[0]):
#     plt.subplot(10,10,ii+1)
#     plt.imshow(im2patch[ii])
#     plt.axis("off")
# plt.subplots_adjust(wspace = 0.05,hspace = 0.05)  
# plt.show()
# plt.imshow(im2patch[5])
# plt.show()

参考

https://blog.csdn.net/qq_15054345/article/details/119902974

https://www.cnblogs.com/walter-xh/p/11755550.html

https://blog.csdn.net/zengwubbb/article/details/115800477

tianchi_CountyAgriculturalBrain_top1/inference.py at 22ae5fe13ca5ec14bfbcef166313f130a29bc272 · HuangQinJian/tianchi_CountyAgriculturalBrain_top1 · GitHub

将图像切分为图像块,并复原 - 知乎 (zhihu.com)

你可能感兴趣的:(python,计算机视觉,深度学习)