Visdrone数据集 | 数据前&后处理操作系列

Visdrone预处理

  • 一、把Visdrone数据集进行切割,生成600*600大小的图片和xml文件
  • 二、提取图片名字
  • 三、将xml标注文件转成yolo需要的标注格式
  • 四、Visdrone直接转成yolo格式
  • 五、使用xml格式画框
  • 六、将txt格式转换成xml格式
  • 七、将原始图像变成二值图像(目标存在是白色,背景是黑色)
  • 八、绘制PR曲线

本篇博客包含:

  1. 切割数据集
  2. 将数据做成Yolo格式
  3. 图像二值化
  4. 绘制评估曲线

github:https://github.com/mary-0830/visdrone-dataset

一、把Visdrone数据集进行切割,生成600*600大小的图片和xml文件

train_crop_visdrone.py

import os
import scipy.misc as misc
from xml.dom.minidom import Document
import numpy as np
import copy, cv2

def save_to_txt(save_path, objects_axis):
    """Write one comma-separated line per object to ``save_path``.

    Args:
        save_path: Destination text file; it is overwritten if it exists.
        objects_axis: Object exposing ``tolist()`` (e.g. a NumPy array) that
            yields a list of rows; every row becomes one line whose values
            are joined with ``,``.

    The last line carries no trailing newline.  An empty ``objects_axis``
    produces an empty file instead of raising ``IndexError``.
    """
    rows = objects_axis.tolist()
    # ``with`` guarantees the handle is closed even if the write fails.
    with open(save_path, 'w') as f:
        f.write('\n'.join(','.join(map(str, row)) for row in rows))


def _append_text_element(doc, parent, tag, text):
    """Create ``<tag>text</tag>``, attach it to ``parent`` and return it."""
    element = doc.createElement(tag)
    element.appendChild(doc.createTextNode(str(text)))
    parent.appendChild(element)
    return element


def save_to_xml(save_path, im_width, im_height, objects_axis, label_name, name, hbb=True):
    """Write a Pascal-VOC style annotation file for one image.

    Args:
        save_path: Destination XML file; it is overwritten if it exists.
        im_width: Value written to ``<size><width>``.
        im_height: Value written to ``<size><height>``.
        objects_axis: Sequence of rows, one per object.  Column 5 is used as
            an index into ``label_name``.  With ``hbb=True`` columns 0-3 are
            written as xmin, ymin, xmax, ymax; otherwise columns 0-7 are
            written as x0, y0, x1, y1, x2, y2, x3, y3.
        label_name: Sequence of class names indexed by column 5.
        name: Text written to ``<filename>``.
        hbb: Selects the four-value box layout (default) or the eight-value
            layout.

    ``<depth>`` is always written as 0, ``<truncated>`` as 1 and
    ``<difficult>`` as 0.
    """
    doc = Document()
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    _append_text_element(doc, annotation, 'folder', 'Visdrone')
    _append_text_element(doc, annotation, 'filename', name)

    source = doc.createElement('source')
    annotation.appendChild(source)
    _append_text_element(doc, source, 'database', 'The Visdrone Database')
    _append_text_element(doc, source, 'annotation', 'Visdrone')
    _append_text_element(doc, source, 'image', 'flickr')
    _append_text_element(doc, source, 'flickrid', '322409915')

    owner = doc.createElement('owner')
    annotation.appendChild(owner)
    _append_text_element(doc, owner, 'flickrid', 'knautia')
    _append_text_element(doc, owner, 'name', 'yang')

    size = doc.createElement('size')
    annotation.appendChild(size)
    _append_text_element(doc, size, 'width', im_width)
    _append_text_element(doc, size, 'height', im_height)
    _append_text_element(doc, size, 'depth', 0)

    _append_text_element(doc, annotation, 'segmented', '0')

    if hbb:
        coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    else:
        coord_tags = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')

    for row in objects_axis:
        obj = doc.createElement('object')
        annotation.appendChild(obj)
        _append_text_element(doc, obj, 'name', label_name[int(row[5])])
        _append_text_element(doc, obj, 'pose', 'Unspecified')
        _append_text_element(doc, obj, 'truncated', '1')
        _append_text_element(doc, obj, 'difficult', '0')
        bndbox = doc.createElement('bndbox')
        obj.appendChild(bndbox)
        # Index explicitly so that a row that is too short still raises.
        for column, tag in enumerate(coord_tags):
            _append_text_element(doc, bndbox, tag, row[column])

    # ``with`` closes the file even when serialisation or writing fails.
    with open(save_path, 'w') as f:
        f.write(doc.toprettyxml(indent=''))

# Class names passed to save_to_xml(); an object's name is looked up by
# using column 5 of its annotation row as an index into this list.
class_list = [
    'ignored regions',
    'pedestrian',
    'people',
    'bicycle',
    'car',
    'van',
    'truck',
    'tricycle',
    'awning-tricycle',
    'bus',
    'motor',
    'others',
]


def format_label(txt_list):
    """Parse raw annotation lines into an integer array.

    Args:
        txt_list: Iterable of comma-separated annotation lines, e.g. the
            result of ``readlines()`` on a label file.

    Returns:
        ``np.ndarray`` with one row per non-blank line, holding the first
        eight comma-separated fields converted to ``int``.  An input without
        any usable line yields an empty array (``len() == 0``).

    Blank or whitespace-only lines (such as a trailing empty line) are
    skipped instead of raising ``ValueError``.
    """
    format_data = []
    for line in txt_list:
        line = line.strip()
        if not line:
            continue
        # Only the first eight fields are used; anything after them
        # (including a trailing comma) is ignored.
        format_data.append([int(field) for field in line.split(',')[:8]])
    return np.array(format_data)

def clip_image(file_idx, image, boxes_all, width, height, stride_w, stride_h):
    """Cut ``image`` into overlapping tiles and save each tile with its boxes.

    A window of ``height`` x ``width`` pixels is moved over the image in
    steps of ``stride_h`` / ``stride_w``.  A window that would run past the
    bottom or right edge is shifted back so that it ends on that edge.  For
    every tile that contains at least one box centre, an XML file is written
    to ``<save_dir>/annotations_600_xml`` and, when the tile is larger than
    5 px in both directions, the tile itself to ``<save_dir>/images_600``.

    Args:
        file_idx: Name stem used for the output files
            (``<file_idx>_<row>_<col>.jpg`` / ``.xml``).
        image: Image array indexed as ``[row, col, ...]``.
        boxes_all: 2-D integer array, one row per object.  Columns 0-3 are
            read as left, top, box width, box height; columns 4 and 5 are
            copied through unchanged.  Nothing is done when it is empty.
        width: Tile width in pixels.
        height: Tile height in pixels.
        stride_w: Horizontal step between windows.
        stride_h: Vertical step between windows.

    Relies on the module-level names ``save_dir``, ``class_list`` and
    ``save_to_xml``.
    """
    if len(boxes_all) == 0:
        return

    full_h, full_w = image.shape[0], image.shape[1]

    xml_dir = os.path.join(save_dir, 'annotations_600_xml')
    img_dir = os.path.join(save_dir, 'images_600')
    # cv2.imwrite only returns False when the folder is missing, so create
    # both output folders up front.
    os.makedirs(xml_dir, exist_ok=True)
    os.makedirs(img_dir, exist_ok=True)

    # Several start positions near the border collapse onto the same
    # shifted window; remember the ones already written.
    done = set()

    for start_h in range(0, full_h, stride_h):
        for start_w in range(0, full_w, stride_w):
            start_h_new = full_h - height if start_h + height > full_h else start_h
            start_w_new = full_w - width if start_w + width > full_w else start_w
            top_left_row = max(start_h_new, 0)
            top_left_col = max(start_w_new, 0)
            bottom_right_row = min(start_h + height, full_h)
            bottom_right_col = min(start_w + width, full_w)

            if (top_left_row, top_left_col) in done:
                continue
            done.add((top_left_row, top_left_col))

            subImage = image[top_left_row:bottom_right_row, top_left_col:bottom_right_col]
            tile_h = bottom_right_row - top_left_row
            tile_w = bottom_right_col - top_left_col

            # Translate boxes into tile coordinates: columns 0/1 become
            # xmin/ymin (clamped at 0), columns 2/3 become xmax/ymax.
            box = np.zeros_like(boxes_all)
            box[:, 0] = np.maximum(boxes_all[:, 0] - top_left_col, 0)
            box[:, 1] = np.maximum(boxes_all[:, 1] - top_left_row, 0)
            box[:, 2] = boxes_all[:, 0] + boxes_all[:, 2] - top_left_col
            box[:, 3] = boxes_all[:, 1] + boxes_all[:, 3] - top_left_row
            box[:, 4] = boxes_all[:, 4]
            box[:, 5] = boxes_all[:, 5]

            # A box belongs to the tile when its centre lies inside it.  The
            # centre is taken before xmax/ymax are clamped so the selection
            # is the same as it has always been.
            center_x = 0.5 * (box[:, 0] + box[:, 2])
            center_y = 0.5 * (box[:, 1] + box[:, 3])
            inside = ((center_x >= 0) & (center_y >= 0)
                      & (center_x <= tile_w) & (center_y <= tile_h))
            idx = np.where(inside)[0]
            if len(idx) == 0:
                continue

            # Keep the written boxes inside the tile.
            box[:, 2] = np.minimum(box[:, 2], tile_w)
            box[:, 3] = np.minimum(box[:, 3], tile_h)

            name = "%s_%04d_%04d.jpg" % (file_idx, top_left_row, top_left_col)
            print(name)
            xml = os.path.join(xml_dir, "%s_%04d_%04d.xml" % (file_idx, top_left_row, top_left_col))
            save_to_xml(xml, subImage.shape[1], subImage.shape[0], box[idx, :], class_list, str(name))
            if subImage.shape[0] > 5 and subImage.shape[1] > 5:
                img = os.path.join(img_dir, name)
                cv2.imwrite(img, subImage)


print ('class_list', len(class_list))
raw_data = 'D:/datasets/VisDrone/VisDrone2019-DET-val/'
raw_images_dir = os.path.join(raw_data, 'images')
raw_label_dir = os.path.join(raw_data, 'annotations')

# clip_image() reads this module-level name to decide where tiles are saved.
save_dir = 'D:/datasets/VisDrone/VisDrone2019-DET-val/'

# Match on the real extension instead of a substring of the file name.
images = [i for i in os.listdir(raw_images_dir) if i.lower().endswith('.jpg')]
labels = [i for i in os.listdir(raw_label_dir) if i.lower().endswith('.txt')]

print ('find image', len(images))
print ('find label', len(labels))

# Tile size and step in pixels.
img_h, img_w, stride_h, stride_w = 600, 600, 450, 450

for idx, img in enumerate(images):
    # scipy.misc.imread is gone from current SciPy and returned RGB data,
    # which cv2.imwrite (used by clip_image) would store with red and blue
    # swapped.  cv2.imread keeps the channel order consistent.
    img_data = cv2.imread(os.path.join(raw_images_dir, img))
    if img_data is None:
        print (idx, 'skip unreadable image', img)
        continue
    print (idx, 'read image', img)

    # splitext removes exactly the extension; str.strip('.jpg') would strip
    # any leading/trailing '.', 'j', 'p' or 'g' characters instead.
    stem = os.path.splitext(img)[0]
    with open(os.path.join(raw_label_dir, stem + '.txt'), 'r') as label_file:
        txt_data = label_file.readlines()

    box = format_label(txt_data)
    # Argument order follows the signature: width, height, stride_w, stride_h.
    clip_image(stem, img_data, box, img_w, img_h, stride_w, stride_h)

# Shell clean-up hint kept from the original script:
#     rm val/images/*   &&   rm val/labeltxt/*

二、提取图片名字

extract_name.py

# P02 批量读取文件名(不带后缀)

import os

# Folder whose entries are turned into the name list.
# NOTE(review): the cropping and xml2yolo scripts use 'annotations_600_xml';
# confirm that 'annotations_600' is the folder intended here.
file_path = "D:/datasets/VisDrone/VisDrone2019-DET-val/annotations_600/"
path_list = os.listdir(file_path)  # names of all entries directly inside file_path
print(path_list)


def saveList(pathName):
    """Append the stem of every name in ``pathName`` to ``name_600_val.txt``.

    The stem is the part of the name before its first ``.``; one stem is
    written per line.  The file lives in the current working directory and is
    opened in append mode, so repeated calls (and repeated runs) accumulate.

    Args:
        pathName: Iterable of file names.
    """
    stems = [file_name.split(".")[0] + "\n" for file_name in pathName]
    if not stems:
        # Nothing to record: leave the file untouched / uncreated.
        return
    # Open the output once instead of once per name.
    with open("name_600_val.txt", "a") as f:
        f.writelines(stems)


def dirList(path_list):
    """Record the contents of every sub-directory found among ``path_list``.

    Each entry is resolved against the module-level ``file_path``; when it is
    a directory, the names inside it are handed to ``saveList``.

    Args:
        path_list: Entry names relative to ``file_path``.
    """
    for entry in path_list:
        path = os.path.join(file_path, entry)
        # The directory test has to run for every entry, not only after the
        # loop (which looked at the last entry alone and failed on an empty
        # list).
        if os.path.isdir(path):
            saveList(os.listdir(path))


# First record the names found inside sub-directories of file_path, then the
# names of the top-level entries themselves.
dirList(path_list)
saveList(path_list)

三、将xml标注文件转成yolo需要的标注格式

xml2yolo.py

# 缺陷坐标xml转txt

import xml.etree.ElementTree as ET
import os


# Class names; they must match the <name> values used in the XML files.
# The position in this list is the class id before the shift applied in
# convert_annotation().
classes = [
    'ignored regions',
    'pedestrian',
    'people',
    'bicycle',
    'car',
    'van',
    'truck',
    'tricycle',
    'awning-tricycle',
    'bus',
    'motor',
    'others',
]

# File that receives the list of image paths, and the text collected for it.
train_file = 'images_val_600_test.txt'
train_file_txt = ''

# Working directory, used as prefix for the image paths.
wd = os.getcwd()

def convert(size, box):
    """Turn a pixel box into normalised ``(x_center, y_center, w, h)``.

    Args:
        size: ``(image_width, image_height)``.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels.

    Returns:
        Tuple ``(x, y, w, h)`` with the centre and extent divided by the
        image width / height.  ``xmax`` and ``ymax`` are first limited to the
        image size; ``xmin`` and ``ymin`` are used as given.
    """
    img_w, img_h = size[0], size[1]
    scale_x = 1. / img_w
    scale_y = 1. / img_h

    xmin, ymin = box[0], box[2]
    # Keep the far edges inside the image.
    xmax = min(box[1], img_w)
    ymax = min(box[3], img_h)

    center_x = ((xmin + xmax) / 2.0) * scale_x
    center_y = ((ymin + ymax) / 2.0) * scale_y
    rel_w = (xmax - xmin) * scale_x
    rel_h = (ymax - ymin) * scale_y
    return (center_x, center_y, rel_w, rel_h)


def convert_annotation(image_id,
                       xml_dir='D:/datasets/VisDrone/VisDrone2019-DET-val/annotations_600_xml',
                       label_dir='D:/datasets/VisDrone/labels_val_600'):
    """Convert one XML annotation file into a YOLO label file.

    Reads ``<xml_dir>/<image_id>.xml`` and writes ``<label_dir>/<image_id>.txt``
    with one line ``<class> <x> <y> <w> <h>`` per kept object.  Objects whose
    name is not in ``classes`` are skipped, as are the first and last class
    (ids 0 and 11); the remaining ids are shifted down by one.

    Args:
        image_id: File name stem shared by the XML and the label file.
        xml_dir: Folder holding the XML files.
        label_dir: Folder receiving the label files; created when missing.
    """
    root = ET.parse('%s/%s.xml' % (xml_dir, image_id)).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    label_lines = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        if cls not in classes:
            continue
        cls_id = classes.index(cls)
        if cls_id in (0, 11):
            continue
        xmlbox = obj.find('bndbox')
        # convert() expects (xmin, xmax, ymin, ymax).
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        label_lines.append(str(cls_id - 1) + " " + " ".join([str(a) for a in bb]) + '\n')

    os.makedirs(label_dir, exist_ok=True)
    # ``with`` makes sure the label file is flushed and closed.
    with open('%s/%s.txt' % (label_dir, image_id), 'w') as out_file:
        out_file.writelines(label_lines)


# Names (without extension) of the XML files to convert.
with open('D:/datasets/VisDrone/name_600_val.txt') as names_file:
    image_ids_train = names_file.read().strip().split()

for image_id in image_ids_train:
    convert_annotation(image_id)

# Build the image list: one absolute image path per XML file found.
anns = os.listdir('./VisDrone2019-DET-val/annotations_600_xml/')
image_prefix = wd + '/VisDrone2019-DET-val' + '/images_600/'
# ann[:-3] keeps the trailing '.', so appending 'jpg' swaps the extension.
train_file_txt += ''.join(
    image_prefix + ann[:-3] + 'jpg\n' for ann in anns if ann[-3:] == 'xml'
)

with open(train_file, 'w') as outfile:
    outfile.write(train_file_txt)

四、Visdrone直接转成yolo格式

trans_yolo.py

import os
from pathlib import Path
from PIL import Image
import csv


def convert(size, box):
    """Turn a ``(left, top, width, height)`` pixel box into normalised form.

    Args:
        size: ``(image_width, image_height)``.
        box: ``(left, top, box_width, box_height)`` in pixels.

    Returns:
        Tuple ``(x_center, y_center, w, h)`` divided by the image width /
        height.
    """
    inv_w = 1. / size[0]
    inv_h = 1. / size[1]
    left, top = box[0], box[1]
    box_w, box_h = box[2], box[3]
    center_x = (left + box_w / 2) * inv_w
    center_y = (top + box_h / 2) * inv_h
    return (center_x, center_y, box_w * inv_w, box_h * inv_h)
            
wd = os.getcwd()

# Folder for the generated label files.
os.makedirs('labels_val', exist_ok=True)

train_file = 'images_val.txt'
train_file_txt = ''

# wd + './VisDrone...' produced '<cwd>./VisDrone...', which is not the
# dataset folder; join the path components instead.
dataset_dir = os.path.join(wd, 'VisDrone2019-DET-val')
annotation_dir = os.path.join(dataset_dir, 'annotations')

image_lines = []
anns = os.listdir(annotation_dir)
for ann in anns:
    if ann[-3:] != 'txt':
        continue
    outpath = wd + '/labels_val/' + ann
    # ann[:-3] keeps the trailing '.', so appending 'jpg' swaps the extension.
    with Image.open(os.path.join(dataset_dir, 'images', ann[:-3] + 'jpg')) as Img:
        img_size = Img.size

    label_lines = []
    with open(os.path.join(annotation_dir, ann), newline='') as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # Blank line in the annotation file.
                continue
            if row[4] == '0':
                # Rows whose fifth field is 0 are left out.
                continue
            bb = convert(img_size, tuple(map(int, row[:4])))
            label_lines.append(str(int(row[5]) - 1) + ' ' + ' '.join(str(a) for a in bb) + '\n')

    # Write the label file once per image instead of once per row.  As
    # before, no file is created when every row was skipped.
    if label_lines:
        with open(outpath, 'w') as outfile:
            outfile.writelines(label_lines)

    # NOTE(review): this lists '<cwd>/images/...', not the folder the images
    # were read from above -- confirm the intended location.
    image_lines.append(wd + '/images/' + ann[:-3] + 'jpg\n')

train_file_txt = ''.join(image_lines)
with open(train_file, 'w') as outfile:
    outfile.write(train_file_txt)

五、使用xml格式画框

draw_visdrone.py

import os
import os.path
import xml.etree.cElementTree as ET
import cv2
def draw(image_path, xml_path, root_saved_path):
    """Draw the annotated boxes on every image and save the result.

    For each ``*.xml`` file in ``xml_path`` the image with the same stem and
    a ``.jpg`` extension is loaded from ``image_path``, one rectangle is drawn
    per ``<object>`` and the image is written to ``root_saved_path`` under
    the same file name.

    Args:
        image_path: Folder containing the ``.jpg`` images.
        xml_path: Folder containing the XML annotation files.
        root_saved_path: Output folder; created when missing.

    Annotations whose image cannot be read are reported and skipped.
    """
    # cv2.imwrite only returns False when the folder does not exist.
    os.makedirs(root_saved_path, exist_ok=True)

    for file in os.listdir(xml_path):
        file_name, suffix = os.path.splitext(file)
        if suffix != '.xml':
            continue

        img_file = os.path.join(image_path, file_name + '.jpg')
        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread signals failure by returning None.
            print('skip %s: cannot read image %s' % (file, img_file))
            continue

        root = ET.parse(os.path.join(xml_path, file)).getroot()
        for obj in root.iter('object'):
            xml_box = obj.find('bndbox')
            # int(float(...)) also accepts coordinates stored as '12.0'.
            x1 = int(float(xml_box.find('xmin').text))
            x2 = int(float(xml_box.find('xmax').text))
            y1 = int(float(xml_box.find('ymin').text))
            y2 = int(float(xml_box.find('ymax').text))
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), thickness=2)

        cv2.imwrite(os.path.join(root_saved_path, file_name + '.jpg'), img)


if __name__ == '__main__':
    # Draw the boxes of the 600x600 training tiles into the 'result' folder.
    draw(
        image_path="D:/datasets/VisDrone/VisDrone2019-DET-train/images_600",
        xml_path="D:/datasets/VisDrone/VisDrone2019-DET-train/annotations_600",
        root_saved_path="D:/datasets/VisDrone/VisDrone2019-DET-train/result",
    )

六、将txt格式转换成xml格式

txt2xml.py

# coding: utf-8
# author: HXY
# 2020-4-17

"""
该脚本用于visdrone数据处理;
将annatations文件夹中的txt标签文件转换为XML文件;
txt标签内容为:
,,,,,,,
类别:
ignored regions(0), pedestrian(1),
people(2), bicycle(3), car(4), van(5),
truck(6), tricycle(7), awning-tricycle(8),
bus(9), motor(10), others(11)
"""

import os
import cv2
import time
from xml.dom import minidom

# Maps the category field of a label line (kept as a string) to the class
# name written into the XML.
name_dict = {
    str(category_id): category_name
    for category_id, category_name in enumerate([
        'ignored regions',
        'pedestrian',
        'people',
        'bicycle',
        'car',
        'van',
        'truck',
        'tricycle',
        'awning-tricycle',
        'bus',
        'motor',
        'others',
    ])
}


def _add_text_child(doc, parent, tag, text):
    """Create ``<tag>text</tag>``, attach it to ``parent`` and return it."""
    node = doc.createElement(tag)
    node.appendChild(doc.createTextNode(str(text)))
    parent.appendChild(node)
    return node


def transfer_to_xml(pic, txt, file_name):
    """Convert one comma-separated label file into an XML annotation file.

    The XML is written to ``<xml_save_path>/<file_name>.xml`` where
    ``xml_save_path`` is the folder hard-coded below.  Image width, height
    and channel count are taken from ``pic``.  Every non-blank line of
    ``txt`` becomes one ``<object>``: fields 0 and 1 are written as xmin and
    ymin, fields 2 and 3 are added to them to give xmax and ymax, and field 5
    is looked up in ``name_dict`` for the class name.

    Args:
        pic: Path of the image belonging to the label file.
        txt: Path of the label file.
        file_name: Name stem used for ``<filename>`` and the output file.

    Raises:
        IOError: If ``pic`` cannot be read.

    The XML is written once, after all lines were processed, and also when
    the label file contains no objects.
    """
    xml_save_path = 'D:/datasets/VisDrone/VisDrone2019-DET-val/annotations_xml'  # folder receiving the XML files
    os.makedirs(xml_save_path, exist_ok=True)

    img = cv2.imread(pic)
    if img is None:
        # cv2.imread returns None instead of raising on failure.
        raise IOError('cannot read image: %s' % pic)
    img_w = img.shape[1]
    img_h = img.shape[0]
    img_d = img.shape[2]

    doc = minidom.Document()
    annotation = doc.createElement("annotation")
    doc.appendChild(annotation)

    _add_text_child(doc, annotation, 'folder', 'visdrone')
    _add_text_child(doc, annotation, 'filename', file_name)

    source = doc.createElement('source')
    _add_text_child(doc, source, 'database', "Unknown")
    annotation.appendChild(source)

    size = doc.createElement('size')
    _add_text_child(doc, size, 'width', img_w)
    _add_text_child(doc, size, 'height', img_h)
    _add_text_child(doc, size, 'depth', img_d)
    annotation.appendChild(size)

    _add_text_child(doc, annotation, 'segmented', "0")

    with open(txt, 'r') as f:
        for line in f:
            # strip() removes the newline; the former strip('/n') removed the
            # characters '/' and 'n' instead.
            line = line.strip()
            if not line:
                continue
            box = line.split(',')
            x_min = box[0]
            y_min = box[1]
            x_max = int(box[0]) + int(box[2])
            y_max = int(box[1]) + int(box[3])
            object_name = name_dict[box[5]]

            obj = doc.createElement('object')
            _add_text_child(doc, obj, 'name', object_name)
            _add_text_child(doc, obj, 'pose', "Unspecified")
            _add_text_child(doc, obj, 'truncated', "1")
            _add_text_child(doc, obj, 'difficult', "0")
            bndbox = doc.createElement('bndbox')
            _add_text_child(doc, bndbox, 'xmin', x_min)
            _add_text_child(doc, bndbox, 'ymin', y_min)
            _add_text_child(doc, bndbox, 'xmax', x_max)
            _add_text_child(doc, bndbox, 'ymax', y_max)
            obj.appendChild(bndbox)
            annotation.appendChild(obj)

    # Serialise once instead of rewriting the file after every object.
    with open(os.path.join(xml_save_path, file_name + '.xml'), 'w') as x:
        x.write(doc.toprettyxml())


if __name__ == '__main__':
    t = time.time()
    print('Transfer .txt to .xml...ing....')
    txt_folder = 'D:/datasets/VisDrone/VisDrone2019-DET-val/annotations'  # folder with the .txt label files
    txt_file = os.listdir(txt_folder)
    img_folder = 'D:/datasets/VisDrone/VisDrone2019-DET-val/images'  # folder with the images

    for txt in txt_file:
        stem, extension = os.path.splitext(txt)
        if extension.lower() != '.txt':
            # Ignore anything in the folder that is not a label file.
            continue
        txt_full_path = os.path.join(txt_folder, txt)
        img_full_path = os.path.join(img_folder, stem + '.jpg')

        try:
            transfer_to_xml(img_full_path, txt_full_path, stem)
        except Exception as e:
            # Best effort: report which file failed and carry on.
            print('failed to convert %s: %s' % (txt, e))

    print("Transfer .txt to .XML sucessed. costed: {:.3f}s...".format(time.time() - t))

七、将原始图像变成二值图像(目标存在是白色,背景是黑色)

img2binary.py

# -*- coding:utf-8 -*-
# 把标注目标映射为二值图像

import matplotlib.pyplot as plt
import cv2, os
from xml.dom import minidom
import xml.etree.ElementTree as ET
import numpy as np

def tobinary(img_path, xml_path=None, out_path=None):
    """Render the annotated boxes of every image as a black/white mask.

    For each ``.jpg`` in ``img_path`` a single-channel image of the same
    height and width is created, filled with zeros, and every ``<bndbox>`` of
    the matching XML file is drawn as a filled rectangle.  The mask is saved
    to ``out_path`` as ``<stem>.jpg``.

    Args:
        img_path: Folder containing the source images.
        xml_path: Folder containing ``<stem>.xml`` annotation files.  When
            omitted, the module-level ``xml_path`` is used.
        out_path: Output folder (created when missing).  When omitted, the
            module-level ``out_path`` is used.

    Raises:
        NameError: If a folder is neither passed nor defined at module level.

    Entries that are not ``.jpg`` files are ignored; images that cannot be
    read are reported and skipped.
    """
    # Callers that only pass img_path rely on the module-level settings.
    if xml_path is None:
        xml_path = globals().get('xml_path')
    if out_path is None:
        out_path = globals().get('out_path')
    if xml_path is None or out_path is None:
        raise NameError("xml_path and out_path must be passed or defined at module level")

    # cv2.imwrite only returns False when the folder does not exist.
    os.makedirs(out_path, exist_ok=True)

    for img_name_id in os.listdir(img_path):
        img_id, extension = os.path.splitext(img_name_id)
        if extension.lower() != '.jpg':
            continue

        img_file = os.path.join(img_path, img_name_id)
        image = cv2.imread(img_file)
        if image is None:
            # cv2.imread signals failure by returning None.
            print('skip unreadable image %s' % img_file)
            continue

        rows, cols = image.shape[0], image.shape[1]
        bg = np.zeros((rows, cols), dtype=np.uint8)

        xml_file = os.path.join(xml_path, img_id + ".xml")
        root = ET.parse(xml_file).getroot()
        color = (255, 255, 255)
        for obj in root.iter("object"):
            xmlbox = obj.find("bndbox")
            # int(float(...)) also accepts coordinates stored as '12.0'.
            xmin = int(float(xmlbox.find('xmin').text))
            ymin = int(float(xmlbox.find('ymin').text))
            xmax = int(float(xmlbox.find('xmax').text))
            ymax = int(float(xmlbox.find('ymax').text))
            # Thickness -1 fills the rectangle.
            cv2.rectangle(bg, (xmin, ymin), (xmax, ymax), color, -1)

        out_file = os.path.join(out_path, img_id + ".jpg")
        cv2.imwrite(out_file, bg)


if __name__ == "__main__":
    img_path = "D:/datasets/VisDrone/VisDrone2019-DET-train/images/"
    # xml_path and out_path are not passed in the call below; tobinary()
    # reads them as module-level names, so they must be assigned first.
    xml_path = "D:/datasets/VisDrone/VisDrone2019-DET-train/annotations_xml/"
    out_path = "D:/datasets/VisDrone/VisDrone2019-DET-train/output/"

    tobinary(img_path)
    # xml = open("D:/datasets/VisDrone/VisDrone2019-DET-train/annotations_xml/0000001_02999_d_0000005.xml")
    # import pdb
    # pdb.set_trace()
    

八、绘制PR曲线

draw.py

from pr import *
# Object classes.
clas = [
    'pedestrian',
    'people',
    'bicycle',
    'car',
    'van',
    'truck',
    'tricycle',
    'awning-tricycle',
    'bus',
    'motor',
]

# Algorithm names, in the same order as the result files below.
visdrone_algorithm = ["yolov5", "centernet", "ucgnet"]
# One result file per algorithm (placed in the Prediction folder).
visdrone_file = [algorithm + ".txt" for algorithm in visdrone_algorithm]
use_07_metric = [False] * 3
# Ground-truth file for each algorithm (placed in the Ground Truth folder).
ground_truth_file = ["visdrone_gt.txt"] * 3

# Alternative configuration for DOTA:
# DOTA_file=["MKD-Net-128.txt"]
# DOTA_algorithm=["MKD-Net-128"]
# use_07_metric=[False]
# ground_truth_file=["dota_gt.txt"]

# Plot the precision-recall curves.
draw_pr(visdrone_file, visdrone_algorithm, ground_truth_file, clas, use_07_metric)

pr.py

# -*- coding: UTF-8 -*-
#以faster和yolo为例
import numpy as np
import math
import matplotlib.pyplot as plt

def get_pr_data_map(prediction_file,ground_truth_file,cls_name,use_07_metric,ovthresh=0.4):
    """Compute one precision/recall curve over all detections, ignoring class.

    Both files are read as space-separated text, one box per line:

    * ground truth: field 0 is the image name and fields 2.. are the box
      (field 1 is not read here -- presumably the class id; TODO confirm).
    * prediction: field 0 is the image name, field 2 the confidence and
      fields 3.. the box (field 1 is likewise not read).

    The IoU code below indexes boxes as (xmin, ymin, xmax, ymax).

    Args:
        prediction_file: Path of the detection results file.
        ground_truth_file: Path of the ground-truth file.
        cls_name: Not used by this function.
        use_07_metric: Forwarded unchanged to ``voc_ap``.
        ovthresh: A detection is a match only when its best IoU is strictly
            greater than this value (default 0.4).

    Returns:
        ``[recall, precision]`` -- two lists with one entry per detection,
        ordered by decreasing confidence.  The value returned by ``voc_ap``
        is printed but not returned.
    """
    with open(ground_truth_file, 'r') as f:  # read the ground-truth file
        lines_gt = f.readlines()          # one string per ground-truth line
    with open(prediction_file, 'r') as f:  # read the prediction file
        lines_pre = f.readlines()         # one string per prediction line
    # Split every ground-truth line on single spaces -> list of field lists.
    splitlines_gt = [x.strip().split(' ') for x in lines_gt]
    # Image name of every ground-truth box (no class filter is applied).
    imagenames_gt = [ x[0] for x in splitlines_gt ] 
    class_recs = {}       
    # Ground-truth boxes: every field from index 2 on, rounded up to an integer.
    BB_gt=np.array([[math.ceil(float(z)) for z in x[2:]] for x in splitlines_gt])
    for i in range(len(imagenames_gt)):
        # First box seen for this image: create its record.
        if imagenames_gt[i] not in class_recs:
            class_recs.update({imagenames_gt[i]:{"bbox":[],"det":[],"difficult":[]}})
        # Append the box; "det" marks whether it has been matched yet and
        # "difficult" is always False here.
        class_recs[imagenames_gt[i]]["bbox"].append(BB_gt[i])
        class_recs[imagenames_gt[i]]["det"].append(False)
        class_recs[imagenames_gt[i]]["difficult"].append(False)
    # Number of positives = number of ground-truth boxes.
    npos=len(imagenames_gt)
    splitlines_pre = [x.strip().split(' ') for x in lines_pre] 
    # Image name of every detection.
    image_ids = [x[0] for x in splitlines_pre] 
    # Confidence of every detection (field 2).
    confidence = np.array([float(x[2]) for x in splitlines_pre]) 
    # Box of every detection (fields 3..), as a 2-D array.
    BB_pre = np.array([[float(z) for z in x[3:]] for x in splitlines_pre]) 
    # Indices that sort the detections by decreasing confidence.
    sorted_ind = np.argsort(-confidence)
    # Negated confidences in ascending order; not used below.
    sorted_scores = np.sort(-confidence)
    # Reorder boxes and image names by decreasing confidence.
    BB_pre = BB_pre[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    pr_data_map=[]
    for d in range(nd):
        # Process detections in order of decreasing confidence.
        if class_recs.get(image_ids[d]):
            R = class_recs[image_ids[d]]
            print(R)
            bb = np.array(BB_pre[d, :]).astype(float)  # current detection box
            ovmax = -np.inf  # best IoU so far; stays -inf without ground truth
            BBGT = np.array(R['bbox']).astype(float)  # ground-truth boxes of this image

            # IoU of the detection with every ground-truth box of the image.
            if BBGT.size > 0:
                ixmin = np.maximum(BBGT[:, 0], bb[0])
                iymin = np.maximum(BBGT[:, 1], bb[1])
                ixmax = np.minimum(BBGT[:, 2], bb[2])
                iymax = np.minimum(BBGT[:, 3], bb[3])
                iw = np.maximum(ixmax - ixmin + 1., 0.)
                ih = np.maximum(iymax - iymin + 1., 0.)
                inters = iw * ih

                uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                       (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                       (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)

                overlaps = inters / uni  # IoU per ground-truth box
                ovmax = np.max(overlaps)  # highest IoU
                jmax = np.argmax(overlaps)  # index of the box with the highest IoU

            # Above the threshold: true positive if that ground-truth box is
            # still unmatched, otherwise a duplicate (false positive).
            if ovmax > ovthresh:
                if not R['difficult'][jmax]:
                    if not R['det'][jmax]:
                        tp[d] = 1.
                        print(tp)
                        R['det'][jmax] = 1
                    else:
                        fp[d] = 1.
            else:
                fp[d] = 1.
        else:
            # The image has no ground-truth boxes: false positive.
            fp[d]=1.


        
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    # recall = cumulative TP / number of ground-truth boxes
    rec = tp / float(npos)
    print(len(rec))
    # precision = cumulative TP / (TP + FP); eps avoids division by zero
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    print(len(prec))
    ap = voc_ap(rec, prec, use_07_metric)
    print(ap)
    # Pack recall and precision as plain nested lists.
    pr_data_map=np.array([rec.tolist(),prec.tolist()])#,ap.tolist()]
    pr_data_map=pr_data_map.tolist()

    return pr_data_map 


def get_pr_data(prediction_file,ground_truth_file,cls_name,use_07_metric,ovthresh=0.5):
    """Compute per-class recall/precision curves and AP from two text files.

    Every non-blank line of both files is whitespace separated:

    * ground truth: ``image class_id xmin ymin xmax ymax``
    * prediction:   ``image class_id confidence xmin ymin xmax ymax``

    ``class_id`` is 1-based: a row belongs to ``cls_name[k]`` when
    ``int(class_id) - 1 == k``.

    Args:
        prediction_file: path of the detections file.
        ground_truth_file: path of the ground-truth file.
        cls_name: sequence of class names, indexed by ``class_id - 1``.
        use_07_metric: forwarded to ``voc_ap`` to select the AP variant.
        ovthresh: a detection is a true positive when its best IoU with a
            not-yet-matched ground-truth box is strictly greater than this.

    Returns:
        dict mapping each class name to ``[rec, prec, ap]``. ``rec`` and
        ``prec`` are cumulative arrays with one entry per detection of that
        class, ordered by decreasing confidence. Classes without detections
        get empty arrays; classes without ground truth get zero recall.
    """
    # Parse and bucket both files once; the original re-split and re-filtered
    # every line for every class.
    gt_by_class = _group_rows_by_class(_read_rows(ground_truth_file))
    pre_by_class = _group_rows_by_class(_read_rows(prediction_file))
    eps = np.finfo(np.float64).eps

    pr_data = {}
    for cls_num, name in enumerate(cls_name):
        # image name -> ground-truth boxes of this class plus "already matched" flags
        class_recs = {}
        gt_rows = gt_by_class.get(cls_num, [])
        for row in gt_rows:
            entry = class_recs.setdefault(row[0], {"bbox": [], "det": []})
            # ground-truth coordinates are rounded up to whole numbers
            entry["bbox"].append([math.ceil(float(z)) for z in row[2:]])
            entry["det"].append(False)
        npos = len(gt_rows)

        # Detections of this class, visited from highest to lowest confidence.
        pre_rows = pre_by_class.get(cls_num, [])
        confidence = np.array([float(row[2]) for row in pre_rows])
        order = np.argsort(-confidence)
        nd = len(pre_rows)
        tp = np.zeros(nd)
        fp = np.zeros(nd)

        for d, k in enumerate(order):
            row = pre_rows[k]
            R = class_recs.get(row[0])
            if R is None:
                # the image has no ground truth of this class
                fp[d] = 1.
                continue

            bb = np.array([float(z) for z in row[3:]])
            BBGT = np.array(R["bbox"], dtype=float)

            # IoU of the detection with every ground-truth box of the image
            # (inclusive pixel convention, hence the "+ 1.")
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                   (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                   (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
            overlaps = inters / uni
            jmax = int(np.argmax(overlaps))
            ovmax = overlaps[jmax]

            # Each ground-truth box can be claimed once; later detections that
            # hit an already-claimed box count as false positives.
            if ovmax > ovthresh and not R["det"][jmax]:
                tp[d] = 1.
                R["det"][jmax] = True
            else:
                fp[d] = 1.

        print(cls_name[cls_num],  np.sum(tp), np.sum(fp), npos)
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        # recall = tp / number of ground-truth boxes; the eps guard keeps a
        # class without ground truth at recall 0 instead of NaN
        rec = tp / np.maximum(float(npos), eps)
        # precision = tp / (tp + fp)
        prec = tp / np.maximum(tp + fp, eps)
        ap = voc_ap(rec, prec, use_07_metric)

        # a repeated class name keeps its first result
        if name not in pr_data:
            pr_data[name] = [rec, prec, ap]
    return pr_data


def _read_rows(path):
    """Return the whitespace-split fields of every non-blank line in *path*."""
    with open(path, 'r') as f:
        return [fields for fields in (line.split() for line in f) if fields]


def _group_rows_by_class(rows):
    """Bucket rows by zero-based class index (second column minus one), keeping file order."""
    grouped = {}
    for row in rows:
        grouped.setdefault(int(row[1]) - 1, []).append(row)
    return grouped

def voc_ap(rec, prec, use_07_metric):
    """Return the average precision (AP) for a recall/precision curve.

    Args:
        rec: array of recall values.
        prec: array of precision values, aligned with ``rec``.
        use_07_metric: when truthy, use 11-point interpolation: average,
            over the thresholds ``np.arange(0., 1.1, 0.1)``, the maximum
            precision among points whose recall reaches the threshold
            (0 when none does). Otherwise integrate the monotonically
            non-increasing precision envelope over the recall steps.

    Returns:
        The AP value.
    """
    if not use_07_metric:
        # Pad with sentinels so the curve spans recall 0..1.
        recall = np.concatenate(([0.], rec, [1.]))
        envelope = np.concatenate(([0.], prec, [0.]))
        # Running maximum taken from the right: each precision becomes the
        # best precision at any equal-or-later position.
        envelope = np.maximum.accumulate(envelope[::-1])[::-1]
        # Only positions where recall changes contribute area.
        steps = np.flatnonzero(recall[1:] != recall[:-1])
        return np.sum((recall[steps + 1] - recall[steps]) * envelope[steps + 1])

    ap = 0.
    for threshold in np.arange(0., 1.1, 0.1):
        reached = rec >= threshold
        best = np.max(prec[reached]) if np.sum(reached) != 0 else 0
        ap = ap + best / 11.
    return ap

def draw_pr(prediction_file,prediction_algorithm,ground_truth_file,cls,use_07_metric):
    """Plot one precision-recall figure per class, comparing several algorithms.

    Args:
        prediction_file: list of detection file names, read from ``Prediction/``.
        prediction_algorithm: list of legend labels, one per prediction file.
        ground_truth_file: list of ground-truth file names (one per prediction
            file), read from ``Ground_Truth/``.
        cls: sequence of class names; one figure is produced for each.
        use_07_metric: list of flags (one per prediction file) forwarded to
            ``get_pr_data``.

    Each figure is saved to ``Image_PR/<class> PR Curve.png`` (the directory is
    created when missing), then shown and closed.
    """
    import os

    # The first six algorithms keep their fixed colours; any further ones fall
    # back to matplotlib's default colour cycle instead of being skipped.
    palette = ["#054E9F", "#FFA500", "#B0C4DE", "#008000", "#BA55D3", "#FF0000"]
    n_algorithms = len(prediction_file)

    # get_pr_data already returns every class, so evaluate each algorithm once
    # rather than once per (class, algorithm) pair.
    results = [
        get_pr_data("Prediction/" + prediction_file[i],
                    "Ground_Truth/" + ground_truth_file[i],
                    cls, use_07_metric[i])
        for i in range(n_algorithms)
    ]

    # Per-algorithm sum of class APs (divide by len(cls) for mAP). Sized from
    # the input rather than fixed at six entries.
    MAP = [0] * n_algorithms

    os.makedirs("Image_PR", exist_ok=True)

    for cls_name in cls:
        title = cls_name + ' PR Curve'
        for i, pr_data in enumerate(results):
            rec, prec, ap = pr_data[cls_name]
            MAP[i] = MAP[i] + ap
            style = {"color": palette[i]} if i < len(palette) else {}
            plt.plot(rec, prec,
                     label=prediction_algorithm[i] + ' AP=' + str(round(ap, 3)),
                     **style)

        plt.title(title, fontsize=15)
        plt.xlabel('Recall', fontsize=15)
        plt.ylabel('Precision', fontsize=15)
        plt.ylim([0.0, 1.0])
        plt.xlim([0.0, 1.0])
        plt.grid(ls='-.')
        plt.legend()
        # Save once all curves of this class are drawn; the original rewrote
        # the same file after every algorithm.
        plt.savefig("Image_PR/" + title + '.png', dpi=600)
        print("save....ok!!!")
        plt.show()
        plt.close()

    #画各个算法map的pr曲线
    # for i in range(len(prediction_file)): 
    #     pr_data_map=get_pr_data_map("Prediction/"+prediction_file[i],"Ground_Truth/"+ground_truth_file[i],cls,use_07_metric[i])
    #     if i==0:
    #         plt.plot(pr_data_map[0],pr_data_map[1],label=prediction_algorithm[i]+' mAP='+str(round(MAP[i]/len(cls),3)),color="#054E9F")
    #     elif i==1:
    #         plt.plot(pr_data_map[0],pr_data_map[1],label=prediction_algorithm[i]+' mAP='+str(round(MAP[i]/len(cls),3)),color="#FFA500")
    #     elif i==2:
    #         plt.plot(pr_data_map[0],pr_data_map[1],label=prediction_algorithm[i]+' mAP='+str(round(MAP[i]/len(cls),3)),color="#B0C4DE")
    #     elif i==3:
    #         plt.plot(pr_data_map[0],pr_data_map[1],label=prediction_algorithm[i]+' mAP='+str(round(MAP[i]/len(cls),3)),color="#008000")
    #     elif i==4:
    #         plt.plot(pr_data_map[0],pr_data_map[1],label=prediction_algorithm[i]+' mAP='+str(round(MAP[i]/len(cls),3)),color="#BA55D3")
    #     elif i==5:
    #         plt.plot(pr_data_map[0],pr_data_map[1],label=prediction_algorithm[i]+' mAP='+str(round(MAP[i]/len(cls),3)),color="#FF0000")
    #     #plt.plot(pr_data_map[0],pr_data_map[1],label=prediction_algorithm[i]+' mAP='+str(round(MAP[i]/len(cls),3)),color="#054E9F")
    #     title='Precision-Recall Curve'
    #     plt.title(title,fontsize=15)
    #     plt.xlabel('Recall',fontsize=15)
    #     plt.ylabel('Precision',fontsize=15)
    #     plt.ylim([0.0, 1.0])
    #     plt.xlim([0.0, 1.0])
    #     plt.grid(ls='-.')
    #     plt.legend()
    #     plt.savefig("Image_PR/"+title+'.png', dpi=600)
    #     print("save....ok!!!")
    # plt.show()
    # plt.close() 

你可能感兴趣的:(#,航拍,&,人脸数据集,计算机视觉,python,目标检测)