为VOC数据集生成训练集和测试集列表

在标准VOC格式的数据集中拆分训练集和测试集,在ImageSets/Main中生成train.txt和test.txt,同时生成YOLO训练使用的train_img_path.txt和test_img_path.txt文件。也可以只从全部标注中按一定比例抽取样本参与拆分。

VOC数据集目录结构如下:

    ├── Annotations [833 entries exceeds filelimit, not opening dir]
    ├── ImageSets
    │   └── Main
    │       ├── test_img_path.txt
    │       ├── test.txt
    │       ├── train_img_path.txt
    │       └── train.txt
    └── JPEGImages [833 entries exceeds filelimit, not opening dir]
 

python代码如下:

# -*- coding: utf-8 -*-
"""
拆分训练集和测试集,并在ImageSet/Main中生成train.txt和test.txt
生成yolo使用的train_img_path.txt和test_img_path.txt文件
可以仅抽选一定比例
"""

import os
import xml.etree.ElementTree as ET
import argparse


def _write_lines(path, lines):
    """Write every string in *lines* to *path*, one per line, encoded as UTF-8."""
    with open(path, 'w', encoding='utf-8') as fh:
        fh.writelines(line + '\n' for line in lines)


def split_train_test_set(root_dir, target_dir, test_fraction=0.1, select_percent=1):
    """
    Split a VOC dataset into train and test subsets.

    Four files are (re)written in ``target_dir``:
      - train.txt / test.txt: image ids (annotation file names without the
        ``.xml`` extension), the list format of VOC ``ImageSets/Main``;
      - train_img_path.txt / test_img_path.txt: image paths built as
        ``<root_dir>/JPEGImages/<filename tag of the XML>``, for yolo training.

    Only ``*.xml`` files in ``<root_dir>/Annotations`` are considered, and they
    are processed in sorted name order so the split is reproducible. Every
    ``int(1/select_percent)``-th annotation is selected; of the selected ones,
    every ``int(1/test_fraction)``-th goes to the test set and the rest go to
    the train set, so the realised test ratio is approximate.

    :param root_dir: VOC root directory containing ``Annotations`` and
        ``JPEGImages``.
    :param target_dir: existing directory the four list files are written to.
    :param test_fraction: approximate fraction of selected samples placed in
        the test set. Values <= 1e-10 produce an empty test set; because the
        interval is ``int(1/test_fraction)`` (clamped to at least 1), any value
        above 0.5 sends every selected sample to the test set.
    :param select_percent: fraction of annotations to use at all, in (0, 1];
        the default 1 uses every annotation.
    :return: None
    :raises ValueError: if ``select_percent`` is not in (0, 1].
    """
    if not 0 < select_percent <= 1:
        raise ValueError(
            'select_percent must be in (0, 1], got {}'.format(select_percent))

    xml_path = os.path.join(root_dir, 'Annotations')
    # Ignore non-annotation entries (e.g. .DS_Store) and sort, because
    # os.listdir order is arbitrary and would make the split non-reproducible.
    files = sorted(name for name in os.listdir(xml_path)
                   if name.lower().endswith('.xml'))
    xml_num = len(files)
    print('=========based on labels(xml): {} ============'.format(xml_num))

    if test_fraction <= 1e-10:
        # Interval larger than the file count: no index ever hits the test set.
        sample_interval = xml_num + 1
    else:
        # Clamp to 1 so test_fraction > 1 cannot produce a zero modulus.
        sample_interval = max(1, int(1 / test_fraction))
    select_interval = int(1 / select_percent)
    print('select sample interval is: {}'.format(select_interval))
    print('test sample interval is: {}'.format(sample_interval))

    train_ids, test_ids = [], []
    train_paths, test_paths = [], []
    for index, xml_file in enumerate(files, start=1):
        if index % select_interval == 0:
            # splitext keeps dots inside the stem ("a.b.xml" -> "a.b").
            img_name = os.path.splitext(xml_file)[0]
            root = ET.parse(os.path.join(xml_path, xml_file)).getroot()
            file_name = root.find('filename').text
            img_path = os.path.join(root_dir, 'JPEGImages', file_name)
            if index % (sample_interval * select_interval) == 0:
                test_ids.append(img_name)
                test_paths.append(img_path)
            else:
                train_ids.append(img_name)
                train_paths.append(img_path)
        if index % 100 == 0:
            print('completed {} / {}'.format(index, xml_num))
    print('total:', xml_num, 'train:', len(train_ids), 'test:', len(test_ids))

    # Write only after every annotation parsed successfully, so a malformed
    # XML does not leave truncated list files behind.
    _write_lines(os.path.join(target_dir, 'train.txt'), train_ids)
    _write_lines(os.path.join(target_dir, 'test.txt'), test_ids)
    _write_lines(os.path.join(target_dir, 'train_img_path.txt'), train_paths)
    _write_lines(os.path.join(target_dir, 'test_img_path.txt'), test_paths)


if __name__ == '__main__':
    # Resolve the home directory portably: environment variable names are
    # case-sensitive on Linux/macOS (HOME, not Home), and Windows has no HOME.
    default_root = os.path.join(os.path.expanduser('~'), "PycharmProjects/data_tool/data/voc")
    parser = argparse.ArgumentParser(
        description="split a VOC dataset into train/test list files")
    parser.add_argument("--root_dir", type=str, default=default_root,
                        help="VOC dataset root containing Annotations/ and JPEGImages/")
    parser.add_argument("--target_dir", type=str, default=None,
                        help="output directory for the list files "
                             "(default: <root_dir>/ImageSets/Main)")
    parser.add_argument("--test_fraction", type=float, default=0.1,
                        help="approximate fraction of samples put in the test set")
    args = parser.parse_args()
    root_dir = args.root_dir
    # The default output location depends on --root_dir, so derive it after parsing.
    target_dir = args.target_dir or os.path.join(root_dir, 'ImageSets', 'Main')
    os.makedirs(target_dir, exist_ok=True)

    split_train_test_set(root_dir, target_dir, args.test_fraction)

 

 

你可能感兴趣的:(Python)