KITTI 数据集转 VOC 格式,再转 COCO 格式

合并类别

此次博主为数据集设置 3 个类别:'Car'、'Cyclist'、'Pedestrian'。不过标注信息中还有其他类型的车和人,直接略过有点浪费,因此博主将 'Van'、'Truck'、'Tram' 合并到 'Car' 类别,将 'Person_sitting' 合并到 'Pedestrian' 类别('Misc' 和 'DontCare' 这两类直接忽略)。参考:前辈的博客。
对原博主的代码进行了一些适当的修改

# modify_annotations_txt.py
'''
    PASCAL VOC数据集总共20个类别,如果用于特定场景,20个类别确实多了。此次博主为数据集设置3个类别,
    ‘Car’,’Cyclist’,’Pedestrian’,只不过标注信息中还有其他类型的车和人,直接略过有点浪费,
    博主希望将 ‘Van’, ‘Truck’, ‘Tram’ 合并到 ‘Car’ 类别中去,将 ‘Person_sitting’ 合并到 ‘Pedestrian’ 类别中去
    (‘Misc’ 和 ‘Dontcare’ 这两类直接忽略)。
'''

# 合并类别
import glob
import os

import tqdm

# Directory holding the KITTI label .txt files (edit to match your setup).
path = 'G:/A_Data/kitti_voc/label/'
# Keep only .txt label files: os.listdir also returns unrelated entries
# (sub-directories, hidden/system files) that the rewrite loop below would
# otherwise try to open and overwrite.
txt_list = sorted(name for name in os.listdir(path) if name.endswith('.txt'))
def show_category(txt_list, label_dir=None):
    """Print (and return) the set of class names used in the given label files.

    Only the first whitespace-separated field of each line (the object class)
    is collected; blank lines are ignored.

    :param txt_list: label file names, relative to ``label_dir``
    :param label_dir: directory containing the files; defaults to the
        module-level ``path`` so existing calls keep working
    :return: the set of class names that was printed
    """
    if label_dir is None:
        label_dir = path
    categories = set()
    for item in txt_list:
        item_path = os.path.join(label_dir, item)
        try:
            with open(item_path) as tdf:
                for each_line in tdf:
                    fields = each_line.split()
                    if fields:  # a blank line carries no class name
                        categories.add(fields[0])
        except IOError as ioerr:
            # Report unreadable files but keep scanning the rest.
            print('File error:' + str(ioerr))
    print(categories)
    return categories

def merge(line):
    """Join label fields back into one line: single spaces between fields, newline at the end.

    :param line: list of field strings (e.g. a split KITTI label line)
    :return: the joined line; an empty list yields just a newline
    """
    return ' '.join(line) + '\n'

# Classes folded into the three target classes (Car / Pedestrian / Cyclist).
CLASS_MAP = {'Truck': 'Car', 'Van': 'Car', 'Tram': 'Car',
             'Person_sitting': 'Pedestrian'}
# Classes dropped from the label files entirely.
IGNORED_CLASSES = {'DontCare', 'Misc'}

# Categories present before the rewrite.
show_category(txt_list)

for item in tqdm.tqdm(txt_list):
    item = path + item
    new_txt = []
    try:
        with open(item, 'r') as r_tdf:
            for each_line in r_tdf:
                labeldata = each_line.strip().split(' ')
                if not labeldata[0]:
                    continue  # skip blank lines instead of writing empty rows back
                labeldata[0] = CLASS_MAP.get(labeldata[0], labeldata[0])
                if labeldata[0] in IGNORED_CLASSES:
                    continue
                new_txt.append(merge(labeldata))
        # The whole file has been read above, so it is safe to truncate and rewrite it.
        with open(item, 'w') as w_tdf:
            w_tdf.writelines(new_txt)
    except IOError as ioerr:
        print('File error:' + str(ioerr))

# Categories present after the rewrite (expected: Car / Cyclist / Pedestrian).
show_category(txt_list)

TXT 转 XML(VOC 格式)

把脚本中的路径(标签目录、图片目录以及 Annotations 输出目录)改成自己的实际路径即可。

# txt_to_xml.py
# encoding:utf-8
# 根据一个给定的XML Schema,使用DOM树的形式从空白文件生成一个XML
from xml.dom.minidom import Document
import cv2
import os

import tqdm


def generate_xml(name, split_lines, img_size, class_ind,
                 out_dir='G:/A_Data/kitti_voc/Annotations/'):
    """Write a Pascal VOC style XML annotation for one KITTI image.

    :param name: image stem (e.g. '000123'); the XML references '<name>.png'
        and is written to '<out_dir>/<name>.xml'
    :param split_lines: raw lines of the matching KITTI label .txt file
    :param img_size: image shape indexed as (height, width, depth), e.g. the
        ``.shape`` of the array returned by ``cv2.imread``
    :param class_ind: class names to export; objects of other classes are skipped
    :param out_dir: directory that receives the XML file
    """
    doc = Document()

    def add_node(parent, tag):
        # Append an empty <tag/> element under parent and return it.
        node = doc.createElement(tag)
        parent.appendChild(node)
        return node

    def add_text(parent, tag, text):
        # Append <tag>text</tag> under parent and return it.
        node = add_node(parent, tag)
        node.appendChild(doc.createTextNode(str(text)))
        return node

    annotation = add_node(doc, 'annotation')
    add_text(annotation, 'folder', 'KITTI')
    add_text(annotation, 'filename', name + '.png')

    source = add_node(annotation, 'source')
    add_text(source, 'database', 'The KITTI Database')
    add_text(source, 'annotation', 'KITTI')

    size = add_node(annotation, 'size')
    add_text(size, 'width', img_size[1])
    add_text(size, 'height', img_size[0])
    add_text(size, 'depth', img_size[2])

    for split_line in split_lines:
        line = split_line.strip().split()
        # Skip blank lines and classes that are not exported.
        if not line or line[0] not in class_ind:
            continue
        obj = add_node(annotation, 'object')
        add_text(obj, 'name', line[0])
        bndbox = add_node(obj, 'bndbox')
        # Label fields 4..7 hold the 2D box as floats; they become
        # xmin / ymin / xmax / ymax, truncated to integers.
        for tag, index in (('xmin', 4), ('ymin', 5), ('xmax', 6), ('ymax', 7)):
            add_text(bndbox, tag, int(float(line[index])))

    with open(os.path.join(out_dir, name + '.xml'), 'w', encoding='utf-8') as f:
        f.write(doc.toprettyxml(indent=''))

if __name__ == '__main__':
    class_ind = ('Pedestrian', 'Car', 'Cyclist')
    cur_dir = 'G:/A_Data/kitti_voc/'  # dataset root (edit to match your setup)
    labels_dir = os.path.join(cur_dir, 'label')
    # Directory with the original KITTI images (edit to match your setup).
    img_dir = 'E:/A____paper/person/data/kitti/training/image_2/'
    for parent, _, filenames in os.walk(labels_dir):  # root dir, sub-dirs, files
        for file_name in tqdm.tqdm(filenames):
            name, ext = os.path.splitext(file_name)
            if ext != '.txt':
                continue  # ignore anything that is not a label file
            with open(os.path.join(parent, file_name)) as f:
                split_lines = f.readlines()
            img_path = os.path.join(img_dir, name + '.png')
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None (no exception) for missing/unreadable files.
                print('Image not found or unreadable: ' + img_path)
                continue
            generate_xml(name, split_lines, img.shape, class_ind)
    print('all txts has converted into xmls')

划分训练集、验证集和测试集(比例 8:1:1)

# create_train_test_txt.py
# encoding:utf-8
import pdb
import glob
import os
import random
import math

def get_sample_value(txt_name, category_name, label_path='G:/A_Data/kitti_voc/label/'):
    """Return the VOC per-class flag for one image: ' 1' if present, '-1' otherwise.

    :param txt_name: label file stem (without '.txt')
    :param category_name: class name to look for
    :param label_path: directory containing the label .txt files
    :return: ' 1' when some line's class field equals ``category_name``,
        else '-1'. If the label file cannot be read, the error is printed
        and '-1' is returned so callers always get a string to write.
    """
    txt_path = os.path.join(label_path, txt_name + '.txt')
    try:
        with open(txt_path) as r_tdf:
            # Compare each line's class field exactly rather than doing a
            # substring search over the whole file text.
            present = any(line.split()[:1] == [category_name] for line in r_tdf)
    except IOError as ioerr:
        print('File error:' + str(ioerr))
        return '-1'
    return ' 1' if present else '-1'

# Collect label file stems (e.g. '000123'), skipping anything that is not a .txt file.
label_dir = 'G:/A_Data/kitti_voc/label/'
txt_list = sorted(
    os.path.splitext(name)[0] for name in os.listdir(label_dir) if name.endswith('.txt')
)
print(txt_list, end='\n\n')

# train:val:test = 8:1:1 -> 9/10 of all samples form trainval, 8/9 of trainval form train.
# Call random.seed(...) before sampling if a reproducible split is needed.
num_trainval = sorted(random.sample(txt_list, math.floor(len(txt_list) * 9 / 10.0)))
print(num_trainval, end='\n\n')

num_train = sorted(random.sample(num_trainval, math.floor(len(num_trainval) * 8 / 9.0)))
print(num_train, end='\n\n')

num_val = sorted(set(num_trainval) - set(num_train))
print(num_val, end='\n\n')

num_test = sorted(set(txt_list) - set(num_trainval))
print(num_test, end='\n\n')

Main_path = 'G:/A_Data/kitti_voc/ImageSets/Main/'
# open(..., 'w') fails when the directory does not exist yet.
os.makedirs(Main_path, exist_ok=True)
# Explicit split-name -> sample-list mapping (replaces eval('num_' + name)).
splits = {'trainval': num_trainval, 'train': num_train, 'val': num_val, 'test': num_test}
category_name = ['Car', 'Pedestrian', 'Cyclist']

for split_name, split_items in splits.items():
    try:
        # <split>.txt: one image id per line.
        with open(Main_path + split_name + '.txt', 'w') as w_tdf:
            w_tdf.writelines(item + '\n' for item in split_items)
        # <Class>_<split>.txt: image id followed by ' 1' / '-1' presence flag.
        for item_category_name in category_name:
            category_txt_name = Main_path + item_category_name + '_' + split_name + '.txt'
            with open(category_txt_name, 'w') as w_tdf:
                for item in split_items:
                    w_tdf.write(item + ' ' + get_sample_value(item, item_category_name) + '\n')
    except IOError as ioerr:
        print('File error:' + str(ioerr))

VOC转COCO

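生成训练集和验证集的 COCO 格式 json 文件,并把对应图片复制到 train_img / val_img 文件夹中。注意:这里是对 Annotations 重新随机划分(10% 作为验证集),与上一步 ImageSets/Main 中的划分无关。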

# -*- coding=utf-8 -*-
#!/usr/bin/python

import sys
import os
import shutil
import numpy as np
import json
import xml.etree.ElementTree as ET

# 检测框的ID起始值
import tqdm

# First id given to a COCO annotation (bounding box); convert() increments it per box.
START_BOUNDING_BOX_ID = 1
# Category name -> id map. It need not be filled in advance: convert() adds every
# class it encounters. NOTE: convert() updates this dict in place, so ids stay
# consistent across successive convert() calls within one run.
PRE_DEFINE_CATEGORIES = {}


def get(root, name):
    """Return a list (possibly empty) of all elements under ``root`` matching ``name``."""
    return root.findall(name)


def get_and_check(root, name, length):
    """Find ``name`` under ``root`` and validate how many matches there are.

    :param root: ElementTree element to search
    :param name: tag / ElementTree path to look for
    :param length: expected number of matches; a value <= 0 disables the
        count check. With ``length == 1`` the single element itself is
        returned instead of a list.
    :return: the matching element (``length == 1``) or the list of matches
    :raises NotImplementedError: when nothing matches, or when the number of
        matches differs from a positive ``length``
    """
    found = root.findall(name)
    if not found:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    count = len(found)
    if length > 0 and count != length:
        raise NotImplementedError(
            'The size of %s is supposed to be %d, but is %d.' % (name, length, count))
    return found[0] if length == 1 else found


# 得到图片唯一标识号
def get_filename_as_int(filename):
    """Return the numeric image id encoded in a file name such as '000123.png'.

    :param filename: image file name whose stem must be an integer
    :return: the stem converted to ``int`` (e.g. '000123.png' -> 123)
    :raises NotImplementedError: if the name is not a string or its stem is
        not an integer (exception type kept for existing callers)
    """
    try:
        filename = os.path.splitext(filename)[0]
        return int(filename)
    except (TypeError, ValueError) as err:
        # Only conversion failures are translated; anything else propagates.
        raise NotImplementedError('Filename %s is supposed to be an integer.' % (filename)) from err


def convert(xml_list, xml_dir, json_file, categories=None):
    """Convert VOC XML annotation files into a single COCO-style json file.

    :param xml_list: XML file names to convert (relative to ``xml_dir``)
    :param xml_dir: directory containing the XML files
    :param json_file: path of the json file to write
    :param categories: dict of class name -> category id, updated in place
        with every new class found. Defaults to the module-level
        ``PRE_DEFINE_CATEGORIES`` so successive calls share the same ids.
    :return: None
    :raises NotImplementedError: on malformed XML structure or a
        non-integer image file name
    :raises ValueError: if a box has non-positive width or height
    """
    json_dict = {"images": [],
                 "type": "instances",
                 "annotations": [],
                 "categories": []}
    if categories is None:
        categories = PRE_DEFINE_CATEGORIES
    bnd_id = START_BOUNDING_BOX_ID
    for line in xml_list:
        line = line.strip()
        root = ET.parse(os.path.join(xml_dir, line)).getroot()
        # Image name: taken from <path> when present, else from <filename>.
        path = get(root, 'path')
        if len(path) == 1:
            filename = os.path.basename(path[0].text)
        elif len(path) == 0:
            filename = get_and_check(root, 'filename', 1).text
        else:
            raise NotImplementedError('%d paths found in %s'%(len(path), line))
        # The file stem must be numeric; it becomes the COCO image id.
        image_id = get_filename_as_int(filename)
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        json_dict['images'].append({'file_name': filename,
                                    'height': height,
                                    'width': width,
                                    'id': image_id})
        # Segmentation masks are not supported; each box becomes one annotation.
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            if category not in categories:
                # New class: the next id is the current number of known classes.
                categories[category] = len(categories)
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(get_and_check(bndbox, 'xmin', 1).text)
            ymin = int(get_and_check(bndbox, 'ymin', 1).text)
            xmax = int(get_and_check(bndbox, 'xmax', 1).text)
            ymax = int(get_and_check(bndbox, 'ymax', 1).text)
            # Explicit check instead of assert (asserts are stripped under python -O).
            if xmax <= xmin or ymax <= ymin:
                raise ValueError('Degenerate box (%d, %d, %d, %d) in %s'
                                 % (xmin, ymin, xmax, ymax, line))
            o_width = xmax - xmin
            o_height = ymax - ymin
            json_dict['annotations'].append({
                'area': o_width * o_height,
                'iscrowd': 0,
                'image_id': image_id,
                'bbox': [xmin, ymin, o_width, o_height],
                'category_id': category_id,
                'id': bnd_id,
                'ignore': 0,
                # Box as a polygon: (xmin,ymin) -> (xmin,ymax) -> (xmax,ymax) -> (xmax,ymin).
                'segmentation': [[xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]],
            })
            bnd_id += 1

    # Every category known so far (including ones from earlier calls sharing the dict).
    for cate, cid in categories.items():
        json_dict['categories'].append({'supercategory': 'none', 'id': cid, 'name': cate})
    with open(json_file, 'w') as json_fp:
        json.dump(json_dict, json_fp)


if __name__ == '__main__':
    root_path = 'G:/A_Data/kitti_voc/'
    # Original KITTI images. They are PNG files (the XML <filename> is '<name>.png'),
    # so copying '<name>.jpg' would fail.
    ori_path = 'E:/A____paper/person/data/kitti/training/image_2/'
    img_ext = '.png'
    xml_dir = os.path.join(root_path, 'Annotations')

    xml_labels = [f for f in os.listdir(xml_dir) if f.endswith('.xml')]
    np.random.shuffle(xml_labels)
    # After shuffling, the first 10% (a random subset) is validation, the rest training.
    split_point = int(len(xml_labels) / 10)

    for xml_list, json_file, img_subdir in (
            (xml_labels[:split_point], 'G:/A_Data/kitti_voc/instances_val2014.json', 'val_img'),
            (xml_labels[split_point:], 'G:/A_Data/kitti_voc/instances_train2014.json', 'train_img')):
        convert(xml_list, xml_dir, json_file)
        dst_dir = os.path.join(root_path, img_subdir)
        os.makedirs(dst_dir, exist_ok=True)  # shutil.copy does not create directories
        for xml_file in tqdm.tqdm(xml_list):
            img_name = os.path.splitext(xml_file)[0] + img_ext
            shutil.copy(os.path.join(ori_path, img_name), os.path.join(dst_dir, img_name))


你可能感兴趣的:(自动驾驶,人工智能,机器学习)