没有overlap的高分辨率带标签图像切片——工作总结

没有overlap的高分辨率带标签图像切片

  • 要求
  • 算法
  • 结果
  • 结论

要求

1.切片之间没有重叠区域,均分为4份或9份,带标签切割。
2.将检测后的切片的检测框移到原图上。

算法

直接看代码

#定义一个函数slice,将输入图像按照给定的裁剪个数进行切片并保存到指定文件夹,没有overlap

import cv2
import numpy as np
import os
import xml.etree.ElementTree as ET
import math
import codecs

def get(root, name):
    vars = root.findall(name)
    return vars

def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and len(vars) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(vars)))
    if length == 1:
        vars = vars[0]
    return vars

def deal_xml(xml_f):
    tree = ET.parse(xml_f) #获取元素树
    root = tree.getroot() #获取根节点,即annotation
    object_list = []
    # 处理每个标注的检测框
    for obj in get(root, 'object'):  #遍历元素树中tag名称为object的标签
        # 取出检测框类别名称
        category = get_and_check(obj, 'name', 1).text
        # 更新类别ID字典
        bndbox = get_and_check(obj, 'bndbox', 1)
        xmin = int(get_and_check(bndbox, 'xmin', 1).text)
        ymin = int(get_and_check(bndbox, 'ymin', 1).text)
        xmax = int(get_and_check(bndbox, 'xmax', 1).text)
        ymax = int(get_and_check(bndbox, 'ymax', 1).text)
        assert (xmax > xmin)
        assert (ymax > ymin)
        obj_info = [xmin, ymin, xmax, ymax, category]
        object_list.append(obj_info)
    return object_list

def exist_objs(slice_im_name, list_2):
    '''
    slice_im_name:切片图像名称
    list_2:原图中的所有目标
    return:原图中位于当前slice中的目标集合,且坐标直接变换为切片中的坐标
    '''
    return_objs = []
    o_name,_ = os.path.splitext(slice_im_name)
    ss_xmin, ss_ymin, ss_xmax, ss_ymax = o_name.split('_')[1:5]
    s_xmin = int(ss_xmin)
    s_ymin = int(ss_ymin)
    s_xmax = int(ss_xmax)
    s_ymax = int(ss_ymax)
    for obj in list_2:
        xmin, ymin, xmax, ymax, category = obj[0], obj[1], obj[2], obj[3], obj[4]
        #第一种:标签在切片内
        if s_xmin <= xmin <= s_xmax and s_ymin <= ymin <= s_ymax:  # 目标点的左上角在切图区域中
            if s_xmin <= xmax <= s_xmax and s_ymin <= ymax <= s_ymax:  # 目标点的右下角在切图区域中
                x_new = xmin - s_xmin
                y_new = ymin - s_ymin
                return_objs.append([x_new, y_new, x_new + (xmax - xmin), y_new + (ymax - ymin), category])
        if s_xmin <= xmin <= s_xmax and ymin < s_ymin:
            #第二种:标签在切片正上方
            if s_xmin <= xmax <= s_xmax and s_ymin <= ymax <= s_ymax:
                x_new = xmin - s_xmin
                y_new = 0
                return_objs.append([x_new, y_new, xmax - s_ymax, ymax - s_ymax, category])
            #第三种:标签在切片右上方
            if xmax > s_xmax and s_ymin <= ymax <= s_ymax:
                x_new = xmin - s_xmin
                y_new = 0
                return_objs.append([x_new, y_new, s_xmax - s_xmin, ymax - s_ymin, category])
        if s_ymin <= ymin <= s_ymax and xmin < s_xmin:
            #第四种:标签在切片正左方
            if s_xmin < xmax <= s_xmax and s_ymin < ymax <= s_ymax:
                x_new = 0
                y_new = ymin - s_ymin
                return_objs.append([x_new, y_new, xmax - s_xmin, ymax - s_ymin, category])
            #第五种:标签在切片左下方
            if s_xmin < xmax < s_xmax and ymax >= s_ymax:
                x_new = 0
                y_new = ymin - s_ymin
                return_objs.append([x_new, y_new, xmax - s_xmin, s_ymax - s_ymin, category])
        #第六种:标签在切片左上方
        if s_xmin > xmin and ymin < s_ymin:
            if s_xmin <= xmax <= s_xmax and s_ymin <= ymax <= s_ymax:
                x_new = 0
                y_new = 0
                return_objs.append([x_new, y_new, xmax - s_xmin, ymax - s_ymin, category])
        if s_xmin <= xmin <= s_xmax and s_ymin <= ymin <= s_ymax:
            #第七种:标签在切片右下方
            if ymax > s_ymax and xmax > s_xmax:
                x_new = xmin - s_xmin
                y_new = ymin - s_ymin
                return_objs.append([x_new, y_new, s_xmax - s_xmin, s_ymax - s_ymin, category])
            #第八种:标签在切片下方
            if s_xmin <= xmax <= s_xmax and ymax > s_ymax:
                x_new = xmin - s_xmin
                y_new = ymin - s_ymin
                return_objs.append([x_new, y_new, xmax - s_xmin, s_ymax - s_ymin, category])
            #第九种:标签在切片正右方
            if xmax > s_xmax and s_ymin <= ymax <= s_ymax:
                x_new = xmin - s_xmin
                y_new = ymin - s_ymin
                return_objs.append([x_new, y_new, s_xmax - s_xmin, ymax - s_ymin, category])
    return return_objs

def make_voc(img_path,savedir,exist_obj_list): #制作切片的xml文件
    if not os.path.exists(savedir):
        os.makedirs(savedir)

    img_name = img_path.split('/')[-1]
    name,ext = os.path.splitext(img_name)
    img = cv2.imread(img_path)
    height,width,layers = img.shape
    with codecs.open(os.path.join(savedir,name+'.xml'), 'w', 'utf-8') as xml:
        xml.write('\n')
        xml.write('\t' + img_name + '\n')
        xml.write('\t\n')
        xml.write('\t\t' + str(width) + '\n')
        xml.write('\t\t' + str(height) + '\n')
        xml.write('\t\t' + str(layers) + '\n')
        xml.write('\t\n')
        for obj in exist_obj_list:
            bbox = obj[:4]
            class_name = obj[-1]
            xmin, ymin, xmax, ymax = bbox
            xml.write('\t\n')
            xml.write('\t\t' + class_name + '\n')
            xml.write('\t\t\n')
            xml.write('\t\t\t' + str(int(xmin)) + '\n')
            xml.write('\t\t\t' + str(int(ymin)) + '\n')
            xml.write('\t\t\t' + str(int(xmax)) + '\n')
            xml.write('\t\t\t' + str(int(ymax)) + '\n')
            xml.write('\t\t\n')
            xml.write('\t\n')
        xml.write('')

def slice(img_dir,xml_dir,img_save_dir,xml_save_dir,m,n,is_label=False): #h维度切m块、w维度切n块,切成mxn块
    if not os.path.exists(img_save_dir):
        os.makedirs(img_save_dir)
    for img_name in os.listdir(img_dir):
        have_xml = False
        name,ext = os.path.splitext(img_name) #分离图片名称和后缀

        xml_path = os.path.join(xml_dir,name+'.xml')
        if os.path.lexists(xml_path) & is_label:
            objects_list = deal_xml(xml_path)
            have_xml = True

        img = cv2.imread(os.path.join(img_dir,img_name))
        h,w = img.shape[:2]
        dy = math.ceil(h/m) #向上取整,避免遗漏像素点
        dx = math.ceil(w/n)
        for i in range(m):
            for j in range(n):
                min_x = j*dx
                min_y = i*dy
                max_x = min(min_x+dx,w)
                max_y = min(min_y+dy,h) #避免切片超出图像范围
                slice_im = img[min_y:max_y,min_x:max_x]
                slice_im_name = name+'_'+str(min_x)+'_'+str(min_y)+'_'+str(max_x)+'_'+str(max_y)+ext
                cv2.imwrite(os.path.join(img_save_dir,slice_im_name),slice_im)
                if have_xml & is_label:
                    objs_list = exist_objs(slice_im_name,objects_list)
                    make_voc(os.path.join(img_save_dir,slice_im_name), xml_save_dir, objs_list)

def merge_labels(img_dir,slice_xml_dir,save_dir): #将切片的检测结果转移到原图
    for img_name in os.listdir(img_dir):
        name,ext = os.path.splitext(img_name)
        objs_list = []
        for xml_name in os.listdir(slice_xml_dir):
            if xml_name.split('_')[0]==name:
                objs_list0 = deal_xml(os.path.join(slice_xml_dir,xml_name))
                if objs_list0 != []:
                    x0 = int(xml_name.split('_')[1])
                    y0 = int(xml_name.split('_')[2])
                    for obj in objs_list0:
                        xmin,ymin,xmax,ymax,category = obj[0],obj[1],obj[2],obj[3],obj[4]
                        xmin = xmin + x0
                        ymin = ymin + y0
                        xmax = xmax + x0
                        ymax = ymax + y0
                        objs_list.append([xmin,ymin,xmax,ymax,category])
        make_voc(os.path.join(img_dir,img_name), save_dir, objs_list)

sample_img_dir = './image/'
slice_img_dir = './slice_results/'
sample_xml_dir = './image_xml/'
slice_xml_dir = './slice_results_xml/'
merge_xml_dir = './merge_results_xml/'

if __name__ == '__main__':
    slice(sample_img_dir, sample_xml_dir,m=3, n=3, img_save_dir=slice_img_dir,xml_save_dir=slice_xml_dir,is_label=True) #is_label意思是是否带标签裁剪
    merge_labels(sample_img_dir,slice_xml_dir,merge_xml_dir)



结果

没有overlap的高分辨率带标签图像切片——工作总结_第1张图片
没有overlap的高分辨率带标签图像切片——工作总结_第2张图片
没有overlap的高分辨率带标签图像切片——工作总结_第3张图片
图一是标注好的原图,为了看的更清楚随便标注了两个大的标签框;
图二是分割后的结果,分割为9份,分割份数作为输入参数可以在函数中修改;
图三是将图二作为检测结果再把标签框复原到原图。

结论

和上一篇相比这个算法没有overlap即重叠区域,这个代码是我自己写的(有些函数参考的是上篇),个人认为比较清晰。两种算法按照检测要求的不同可以选择使用,解决了高分辨率图像小目标情况下检测困难的问题。

你可能感兴趣的:(python,opencv,pytorch)