目标检测DOTA数据集预处理相关函数

目录

1.从DOTA数据集中选出自己需要的类别

2.DOTA数据gt可视化

3.对DOTA数据进行分割

4.分割后处理

5.转换成VOC形式的xml文件

6.对xml形式的数据进行数据扩增


1.从DOTA数据集中选出自己需要的类别

import os
import shutil
import cv2

catogory = ['ship']  #指定类别的名称

def custombasename(fullname):
    """Return the file name of *fullname* without directory and last extension."""
    stem, _extension = os.path.splitext(fullname)
    return os.path.basename(stem)
  
def GetFileFromThisRootDir(dir, ext=None):
    """Recursively collect the paths of all files below *dir*.

    :param dir: root directory to walk (os.walk order is preserved).
    :param ext: None to return every file; otherwise either a single
        extension string ('png' or '.png') or a collection of extensions
        without the leading dot (e.g. ['png', 'jpg']).
    :return: list of file paths (joined onto *dir*).
    """
    if isinstance(ext, str):
        # A bare string used to be matched with substring semantics, so files
        # without an extension ('' in 'png' is True) or with a partial one
        # ('pn') slipped through. Normalise it to an exact one-element set.
        ext = {ext.lstrip('.')}
    allfiles = []
    for root, _dirs, files in os.walk(dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            extension = os.path.splitext(filepath)[1][1:]  # extension without the dot
            if ext is None or extension in ext:
                allfiles.append(filepath)
    return allfiles
  
  
  
if __name__ == '__main__':
    root1 = 'H:/DOTA_biqi/Org_data/DOTA/train'
    pic_path = os.path.join(root1, 'images')  # source images
    label_path = os.path.join(root1, 'labelTxt')  # source DOTA label files
    label_list = GetFileFromThisRootDir(label_path)
    # Output folders for the selected samples.
    out_pic = 'C:/Users/wytwh/Desktop/ship/train/images'
    out_label = 'C:/Users/wytwh/Desktop/ship/train/labelTxt'
    # A sample is selected when it contains at least this many objects whose
    # category is listed in `catogory`.
    min_count = 2
    # shutil.copyfile raises when the destination folder does not exist (and
    # cv2.imwrite does not create folders), so create them up front.
    for folder in (out_pic, out_label):
        if not os.path.exists(folder):
            os.makedirs(folder)
    for labelpath in label_list:
        with open(labelpath, 'r') as f:
            lines = f.readlines()
        splitlines = [x.strip().split(' ') for x in lines]  # split on spaces
        n = 0
        for i, splitline in enumerate(splitlines):
            if i in [0, 1]:  # the first two lines of a DOTA label file are headers, not objects
                continue
            if len(splitline) < 9:  # blank or malformed line: no category field
                continue
            catogory_name = splitline[8]  # category name
            if catogory_name not in catogory:
                continue
            n = n + 1
            if n >= min_count:
                name = custombasename(labelpath)
                oldimgpath = os.path.join(pic_path, name + '.png')
                img = cv2.imread(oldimgpath)
                if img is None:  # missing or unreadable image
                    print('cannot read %s, skipped' % oldimgpath)
                    break
                newlabelpath = os.path.join(out_label, name + '.txt')
                # Re-encode instead of copying so the image extension can change.
                newimage_path = os.path.join(out_pic, name + '.tif')
                cv2.imwrite(newimage_path, img)
                shutil.copyfile(labelpath, newlabelpath)
                break

2.DOTA数据gt可视化

   见博客目标检测可视化gt

3.对DOTA数据进行分割

我修改的地方是：对于分割后被截断的目标，如果其落在子图内部分的面积与完整目标面积之比大于 thresh（默认 0.7），则保留该目标，否则丢弃。

import os
import codecs
import numpy as np
import math
from dota_utils import GetFileFromThisRootDir
import cv2
import shapely.geometry as shgeo
import dota_utils as util
import copy


def choose_best_pointorder_fit_another(poly1, poly2):
    """Rotate the vertex order of quadrilateral *poly1* to best match *poly2*.

    *poly1* is read as four (x, y) vertices stored flat as
    [x1, y1, x2, y2, x3, y3, x4, y4]. The four cyclic rotations of that vertex
    list are compared with *poly2* by the sum of squared coordinate
    differences, and the closest rotation is returned as a numpy array.
    """
    vertices = [(poly1[2 * k], poly1[2 * k + 1]) for k in range(4)]
    candidates = []
    for shift in range(4):
        rotated = vertices[shift:] + vertices[:shift]
        candidates.append(np.array([coord for vertex in rotated for coord in vertex]))
    target = np.array(poly2)
    scores = np.array([np.sum((cand - target) ** 2) for cand in candidates])
    best = scores.argsort()[0]
    return candidates[best]

def cal_line_length(point1, point2):
    """Return the Euclidean distance between two (x, y) points."""
    dx = point1[0] - point2[0]
    dy = point1[1] - point2[1]
    squared = math.pow(dx, 2) + math.pow(dy, 2)
    return math.sqrt(squared)


class splitbase():
    """Cut large DOTA images and their polygon labels into overlapping square patches.

    Patch images are written to ``<outpath>/images`` (same extension as the
    source image) and patch labels to ``<outpath>/labelTxt``, one line per
    object: ``x1 y1 x2 y2 x3 y3 x4 y4 name difficult`` in patch coordinates.
    An object cut by a patch border is written only when the part inside the
    patch covers more than ``thresh`` of the object's area.
    """

    def __init__(self,
                 basepath,
                 outpath,
                 code = 'utf-8',
                 gap=100,
                 subsize=1024,
                 thresh=0.7,
                 choosebestpoint=True,
                 ):
        """
        :param basepath: base path of the DOTA data; images are read from
            ``basepath/images`` and labels from ``basepath/labelTxt``
        :param outpath: output base path; ``images`` and ``labelTxt``
            sub-directories are created there if missing
        :param code: encoding used when writing the label txt files
        :param gap: overlap in pixels between two neighbouring patches, so that
            objects lying on a patch border are less likely to be cut
        :param subsize: side length in pixels of each square patch
        :param thresh: a cut object is kept only if
            (area inside the patch) / (full object area) > thresh
        :param choosebestpoint: if True, rotate the vertex order of a cut polygon
            so that it best matches the vertex order of the original polygon
        """
        self.basepath = basepath
        self.outpath = outpath
        self.code = code
        self.gap = gap
        self.subsize = subsize
        self.slide = self.subsize - self.gap  # stride between patch origins
        self.thresh = thresh
        self.imagepath = os.path.join(self.basepath, 'images')
        self.labelpath = os.path.join(self.basepath, 'labelTxt')
        self.outimagepath = os.path.join(self.outpath, 'images')
        self.outlabelpath = os.path.join(self.outpath, 'labelTxt')
        self.choosebestpoint = choosebestpoint
        if not os.path.exists(self.outimagepath):
            os.makedirs(self.outimagepath)
        if not os.path.exists(self.outlabelpath):
            os.makedirs(self.outlabelpath)

    def polyorig2sub(self, left, up, poly):
        """Shift a flat [x1, y1, x2, y2, ...] polygon into patch coordinates.

        Each shifted coordinate is truncated with int(); the result is a float
        numpy array of the same length as *poly*.
        """
        polyInsub = np.zeros(len(poly))
        for i in range(int(len(poly)/2)):
            polyInsub[i * 2] = int(poly[i * 2] - left)
            polyInsub[i * 2 + 1] = int(poly[i * 2 + 1] - up)
        return polyInsub

    def calchalf_iou(self, poly1, poly2):
        """Return (intersection geometry, area(poly1 & poly2) / area(poly1)).

        This is not the usual IoU: the denominator is the area of *poly1* only,
        so the caller must ensure poly1.area > 0.
        """
        inter_poly = poly1.intersection(poly2)
        inter_area = inter_poly.area
        poly1_area = poly1.area
        half_iou = inter_area / poly1_area
        return inter_poly, half_iou

    def saveimagepatches(self, img, subimgname, left, up, ext):
        """Crop the subsize x subsize patch at (left, up) and write it as subimgname + ext."""
        subimg = copy.deepcopy(img[up: (up + self.subsize), left: (left + self.subsize)])
        outdir = os.path.join(self.outimagepath, subimgname + ext)
        cv2.imwrite(outdir, subimg)

    def GetPoly4FromPoly5(self, poly):
        """Reduce a 5-vertex polygon (flat list of 10 coords) to 4 vertices.

        The shortest edge (including the closing edge from vertex 5 back to
        vertex 1) is collapsed into its midpoint; the other vertices keep
        their order.
        """
        distances = [cal_line_length((poly[i * 2], poly[i * 2 + 1] ), (poly[(i + 1) * 2], poly[(i + 1) * 2 + 1])) for i in range(int(len(poly)/2 - 1))]
        distances.append(cal_line_length((poly[0], poly[1]), (poly[8], poly[9])))
        pos = np.array(distances).argsort()[0]
        count = 0
        outpoly = []
        while count < 5:
            if (count == pos):
                # Replace vertex `pos` by the midpoint of the edge pos -> pos+1.
                outpoly.append((poly[count * 2] + poly[(count * 2 + 2)%10])/2)
                outpoly.append((poly[(count * 2 + 1)%10] + poly[(count * 2 + 3)%10])/2)
                count = count + 1
            elif (count == (pos + 1)%5):
                # The other endpoint of the shortest edge is dropped.
                count = count + 1
                continue
            else:
                outpoly.append(poly[count * 2])
                outpoly.append(poly[count * 2 + 1])
                count = count + 1
        return outpoly

    def savepatches(self, resizeimg, objects, subimgname, left, up, right, down, ext):
        """Write the label file of one patch, then save the patch image.

        :param resizeimg: the (possibly resized) full image
        :param objects: parsed objects; each dict is read for 'poly' (flat list
            whose first 8 values are the quadrilateral corners), 'name' and
            'difficult'
        :param subimgname: base name (without extension) of the patch files
        :param left, up: top-left corner of the patch in the full image
        :param right, down: bottom-right corner of the patch polygon
        :param ext: image file extension including the dot
        """
        outdir = os.path.join(self.outlabelpath, subimgname + '.txt')
        imgpoly = shgeo.Polygon([(left, up), (right, up), (right, down),
                                 (left, down)])
        with codecs.open(outdir, 'w', self.code) as f_out:
            for obj in objects:
                gtpoly = shgeo.Polygon([(obj['poly'][0], obj['poly'][1]),
                                         (obj['poly'][2], obj['poly'][3]),
                                         (obj['poly'][4], obj['poly'][5]),
                                         (obj['poly'][6], obj['poly'][7])])
                if (gtpoly.area <= 0):
                    continue
                inter_poly, half_iou = self.calchalf_iou(gtpoly, imgpoly)

                if (half_iou == 1):
                    # Entirely inside the patch: only shift into patch coordinates.
                    polyInsub = self.polyorig2sub(left, up, obj['poly'])
                    outline = ' '.join(list(map(str, polyInsub)))
                    outline = outline + ' ' + obj['name'] + ' ' + str(obj['difficult'])
                    f_out.write(outline + '\n')
                elif (half_iou > 0):
                    # Cut by the patch border: drop it unless enough of it survives.
                    if half_iou <= self.thresh:
                        continue
                    if inter_poly.geom_type != 'Polygon':
                        # A non-convex label can be cut into several pieces
                        # (MultiPolygon / GeometryCollection) that have no single
                        # exterior ring; skip it instead of crashing.
                        continue
                    inter_poly = shgeo.polygon.orient(inter_poly, sign=1)
                    out_poly = list(inter_poly.exterior.coords)[0: -1]  # drop the closing vertex
                    if len(out_poly) < 4:
                        continue

                    out_poly2 = []
                    for point in out_poly:
                        out_poly2.append(point[0])
                        out_poly2.append(point[1])

                    if (len(out_poly) == 5):
                        out_poly2 = self.GetPoly4FromPoly5(out_poly2)
                    elif (len(out_poly) > 5):
                        # Cut polygons with more than 5 vertices are not handled.
                        continue
                    if (self.choosebestpoint):
                        out_poly2 = choose_best_pointorder_fit_another(out_poly2, obj['poly'])

                    polyInsub = self.polyorig2sub(left, up, out_poly2)

                    # Clamp the coordinates into [1, subsize].
                    for index, item in enumerate(polyInsub):
                        if (item <= 1):
                            polyInsub[index] = 1
                        elif (item >= self.subsize):
                            polyInsub[index] = self.subsize
                    outline = ' '.join(list(map(str, polyInsub)))
                    outline = outline + ' ' + obj['name'] + ' ' + str(obj['difficult'])
                    f_out.write(outline + '\n')
        self.saveimagepatches(resizeimg, subimgname, left, up, ext)

    def SplitSingle(self, imgpath, rate):
        """Split one image and its ground truth into patches.

        The label file is looked up as ``<labelpath>/<image name>.txt``. Patch
        files are named ``<name>__<rate>__<left>___<up>`` and keep the image's
        extension. Unreadable images are silently skipped.

        :param imgpath: path of the image to split
        :param rate: scale factor applied to the image and its polygons before cutting
        """
        img = cv2.imread(imgpath)
        name = util.custombasename(imgpath)  # image name without extension
        extent = os.path.splitext(imgpath)[-1]  # image extension, including the dot
        if img is None:  # cv2.imread returns None when the file cannot be read
            return
        fullname = os.path.join(self.labelpath, name + '.txt')
        objects = util.parse_dota_poly2(fullname)
        for obj in objects:
            obj['poly'] = list(map(lambda x:rate*x, obj['poly']))

        if (rate != 1):
            resizeimg = cv2.resize(img, None, fx=rate, fy=rate, interpolation = cv2.INTER_CUBIC)
        else:
            resizeimg = img
        outbasename = name + '__' + str(rate) + '__'
        width = np.shape(resizeimg)[1]
        height = np.shape(resizeimg)[0]

        # Slide a subsize window with stride `slide`; the last window in each
        # direction is pushed back so that it ends exactly at the image border.
        left, up = 0, 0
        while (left < width):
            if (left + self.subsize >= width):
                left = max(width - self.subsize, 0)
            up = 0
            while (up < height):
                if (up + self.subsize >= height):
                    up = max(height - self.subsize, 0)
                right = min(left + self.subsize, width - 1)
                down = min(up + self.subsize, height - 1)
                subimgname = outbasename + str(left) + '___' + str(up)
                self.savepatches(resizeimg, objects, subimgname, left, up, right, down, extent)
                if (up + self.subsize >= height):
                    break
                else:
                    up = up + self.slide
            if (left + self.subsize >= width):
                break
            else:
                left = left + self.slide

    def splitdata(self, rate):
        """Split every file found under ``<basepath>/images``.

        :param rate: resize rate applied before cutting
        """
        imagelists = GetFileFromThisRootDir(self.imagepath)
        for imgpath in imagelists:
            print('正在处理 %s'%imgpath)
            self.SplitSingle(imgpath, rate)

if __name__ == '__main__':
    # Example: split every image under <train>/images into patches using the
    # default patch size / gap, without resizing (rate = 1), and write the
    # results below <train>/examplesplit.
    src_dir = r'/home/yantianwang/lala/ship/train'
    dst_dir = r'/home/yantianwang/lala/ship/train/examplesplit'
    splitter = splitbase(src_dir, dst_dir)
    splitter.splitdata(1)

4.分割后处理

4.1.除去分割后的空白样本

import os
import shutil

def custombasename(fullname):
    """Return the file name of *fullname* without directory and last extension."""
    stem, _extension = os.path.splitext(fullname)
    return os.path.basename(stem)

def GetFileFromThisRootDir(dir, ext=None):
    """Recursively collect the paths of all files below *dir*.

    :param dir: root directory to walk (os.walk order is preserved).
    :param ext: None to return every file; otherwise either a single
        extension string ('png' or '.png') or a collection of extensions
        without the leading dot (e.g. ['png', 'jpg']).
    :return: list of file paths (joined onto *dir*).
    """
    if isinstance(ext, str):
        # A bare string used to be matched with substring semantics, so files
        # without an extension ('' in 'png' is True) or with a partial one
        # ('pn') slipped through. Normalise it to an exact one-element set.
        ext = {ext.lstrip('.')}
    allfiles = []
    for root, _dirs, files in os.walk(dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            extension = os.path.splitext(filepath)[1][1:]  # extension without the dot
            if ext is None or extension in ext:
                allfiles.append(filepath)
    return allfiles
  
def cleandata(path, img_path, blank_label_path, blank_img_path, ext):
    """Move a patch that has no objects (label + image) into the 'blank' folders.

    :param path: label txt file of one patch
    :param img_path: folder holding the patch images
    :param blank_label_path: destination folder for empty label files
    :param blank_img_path: destination folder for the matching images
    :param ext: image extension including the dot, e.g. '.tif'
    """
    name = custombasename(path)  # patch name
    # Close the label file before any move: the handle used to leak for every
    # non-empty label, and an open file cannot be moved on Windows.
    with open(path, 'r') as f_in:
        lines = f_in.readlines()
    # A label file that holds only blank lines has no objects either.
    if not any(line.strip() for line in lines):
        image_path = os.path.join(img_path, name + ext)  # the patch image
        shutil.move(image_path, blank_img_path)
        shutil.move(path, blank_label_path)
    print('正在处理 %s'%path)
            
                                           
if __name__ == '__main__':
    root = '/home/yantianwang/lala/ship/train/examplesplit'
    img_path = os.path.join(root, 'images')        # split patch images
    label_path = os.path.join(root, 'labelTxt')    # split patch labels
    ext = '.tif'                                   # patch image extension
    # Destination folders for patches whose label file is empty.
    blank_img_path = os.path.join(root, 'blank_images')
    blank_label_path = os.path.join(root, 'blank_labelTxt')
    for folder in (blank_img_path, blank_label_path):
        if not os.path.exists(folder):
            os.makedirs(folder)

    for path in GetFileFromThisRootDir(label_path):
        cleandata(path, img_path, blank_label_path, blank_img_path, ext)

4.2.除去分割后已经不含我们需要的目标的样本

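对第1节的代码稍作修改即可：把输入路径改为分割后的 images 和 labelTxt 目录，只挑选出仍包含所需类别目标的样本。注意，分割后生成的标签文件只包含目标行，没有 DOTA 原始标签的前两行头信息，因此需要去掉跳过前两行的判断（`if i in [0,1]`），否则会漏掉前两个目标。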

5.转换成VOC形式的xml文件

见博客将DOTA标签格式转为VOC格式形成xml文件

6.对xml形式的数据进行数据扩增

见博客目标检测数据扩增

你可能感兴趣的:(python)