【JSON格式划分数据】json格式数据划分训练集、验证集

json格式数据划分训练集、验证集


用法:可用于COCO格式的数据集的划分,还可添加测试集或者多个训练+验证集。根据自身需求,修改一下代码就好啦~

split_train_val.py
执行 `python split_train_val.py --img_path <图片文件夹> --json_file <标注文件.json> --output <输出目录>` 即可(`--json_file` 需指定单个 json 标注文件)。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# 将一个文件夹下图片按比例分在两个文件夹下
import os
import random
from shutil import copy2
import json
import argparse


def main(args):
    """Split a folder of ``.jpg`` images and its COCO-style JSON into train/val.

    The images found directly inside ``args.img_path`` are shuffled with fixed
    seeds and copied into ``<img_path>/train`` and ``<img_path>/val``. The
    annotation file is then split accordingly and written to
    ``<output>/train.json`` and ``<output>/val.json``.

    An image entry (or annotation) is assigned to a subset when its ``id``
    (or ``image_id``) equals the file name up to the first ``.`` of an image
    copied into that subset, so the JSON ids are expected to be those
    file-name stems.

    Every ``category_id`` in the annotations and every category ``id`` is
    decremented by one so that categories start at 0 (this assumes the input
    uses 1-based category ids).

    Args:
        args: Namespace with the attributes ``img_path`` (folder of source
            images), ``json_file`` (path of the annotation JSON) and
            ``output`` (folder that receives ``train.json``/``val.json``; an
            empty string means the current working directory).
    """
    train_ratio = 0.5  # fraction of the images that goes to the training set

    all_data = os.listdir(args.img_path)  # image folder

    random.seed(1)
    random.shuffle(all_data)  # first shuffle: the raw directory listing
    all_data_img = [name for name in all_data if name.endswith(".jpg")]
    num_all_data = len(all_data_img)
    print("num_all_data: " + str(num_all_data))
    index_list = list(range(num_all_data))
    random.seed(2)
    random.shuffle(index_list)  # second shuffle: the order images are assigned in

    train_dir = os.path.join(args.img_path, "train")  # training images go here
    os.makedirs(train_dir, exist_ok=True)

    valid_dir = os.path.join(args.img_path, "val")  # validation images go here
    os.makedirs(valid_dir, exist_ok=True)

    train_list = []
    val_list = []
    for num, idx in enumerate(index_list):
        img_name = all_data_img[idx]
        src_path = os.path.join(args.img_path, img_name)
        stem = img_name.split('.')[0]
        if num < num_all_data * train_ratio:
            train_list.append(stem)
            copy2(src_path, os.path.join(train_dir, img_name))
        else:
            val_list.append(stem)
            copy2(src_path, os.path.join(valid_dir, img_name))

    print("train_nums", len(train_list))
    print("val_nums", len(val_list))

    # Sets give O(1) membership tests while filtering images and annotations.
    train_ids = set(train_list)
    val_ids = set(val_list)

    with open(args.json_file, encoding="utf-8") as f:
        data = json.load(f)

    train_json_dict = {
        "images": [],
        "annotations": [],
        "categories": [],
        "type": "instances",
    }
    val_json_dict = {
        "images": [],
        "annotations": [],
        "categories": [],
        "type": "instances",
    }

    # images
    for image in data['images']:
        if image['id'] in train_ids:
            train_json_dict['images'].append(image)
        if image['id'] in val_ids:
            val_json_dict['images'].append(image)

    # annotations
    for annotation in data['annotations']:
        annotation['category_id'] -= 1  # make category ids start at 0
        if annotation["image_id"] in train_ids:
            train_json_dict['annotations'].append(annotation)
        if annotation["image_id"] in val_ids:
            val_json_dict['annotations'].append(annotation)

    # categories are shared by both subsets; ids are shifted to start at 0
    for category in data['categories']:
        category['id'] -= 1
        train_json_dict['categories'].append(category)
        val_json_dict['categories'].append(category)

    # Create the output folder on demand; '' means the current directory.
    if args.output:
        os.makedirs(args.output, exist_ok=True)

    with open(os.path.join(args.output, "train.json"), "w", encoding="utf-8") as f:
        json.dump(train_json_dict, f, indent=2)

    with open(os.path.join(args.output, "val.json"), "w", encoding="utf-8") as f:
        json.dump(val_json_dict, f, indent=2)

if __name__ == "__main__":
    # Command-line entry point: parse the three paths and run the split.
    parser = argparse.ArgumentParser(description='Start convert.')
    # Folder holding the source .jpg images; train/ and val/ are created inside it.
    parser.add_argument('--img_path', type=str, required=True, help='raw images path')
    # Path of the annotation JSON file to split.
    parser.add_argument('--json_file', type=str, required=True, help='json file path')
    # Folder that receives train.json and val.json ('' = current directory).
    parser.add_argument('--output', type=str, help='output path', default='')
    args = parser.parse_args()
    main(args)

你可能感兴趣的:(#,实验代码,计算机视觉,深度学习,python)