将 COCO 格式的 JSON 训练标注文件按比例划分为训练集和验证集两个 JSON 文件

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# 将一个文件夹下图片按比例分在两个文件夹下
import os
import random
from shutil import copy2
import json
import argparse


def main(args):
    """Split a COCO-style annotation file and its images into train/val sets.

    Every ``.jpg`` file in ``args.img_path`` is shuffled with fixed seeds, so
    the split is reproducible for the same folder listing. The first
    ``train_ratio`` fraction of the shuffled files (0.8 by default) goes to
    the training split and the rest to the validation split.

    Each file with an entry in the JSON ``images`` list is copied to
    ``<args.img_path2>/train`` or ``<args.img_path2>/val``. ``train.json`` and
    ``val.json`` are written to ``args.output``. Each holds the matching
    ``images`` and ``annotations`` plus all ``categories``.

    Args:
        args: Namespace-like object with attributes
            ``img_path`` (source image folder),
            ``img_path2`` (folder under which ``train``/``val`` are created),
            ``json_file`` (annotation JSON with ``images``, ``annotations``
            and ``categories`` lists),
            ``output`` (folder for ``train.json`` / ``val.json``) and,
            optionally, ``train_ratio`` (float, defaults to 0.8).
    """
    train_ratio = getattr(args, "train_ratio", 0.8)

    all_data = os.listdir(args.img_path)
    # Both seeded shuffles are kept exactly as before so that existing
    # splits stay reproducible.
    random.seed(1)
    random.shuffle(all_data)
    all_data_img = [name for name in all_data if name.endswith(".jpg")]
    num_all_data = len(all_data_img)
    print("num_all_data: " + str(num_all_data))
    index_list = list(range(num_all_data))
    random.seed(2)
    random.shuffle(index_list)

    train_dir = os.path.join(args.img_path2, "train")
    val_dir = os.path.join(args.img_path2, "val")
    # makedirs also creates a missing img_path2 (os.mkdir would fail there).
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    with open(args.json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Map file_name -> id of the FIRST image entry with that name.
    # Building it once replaces a full scan of data["images"] per file.
    # It also gives every file exactly one id, in both splits.
    name_to_id = {}
    for image in data["images"]:
        name_to_id.setdefault(image["file_name"], image["id"])

    train_list = []
    val_list = []
    missing = []
    for num, idx in enumerate(index_list):
        name = all_data_img[idx]
        # The position in the shuffled order decides the split, whether or
        # not the file is annotated (same counting as the original script).
        if num < num_all_data * train_ratio:
            ids, dest_dir = train_list, train_dir
        else:
            ids, dest_dir = val_list, val_dir
        image_id = name_to_id.get(name)
        if image_id is None:
            missing.append(name)
            continue
        ids.append(image_id)
        copy2(os.path.join(args.img_path, name), os.path.join(dest_dir, name))

    print("train_nums", len(train_list))
    print("val_nums", len(val_list))
    if missing:
        print("skipped %d image(s) without an entry in %s"
              % (len(missing), args.json_file))

    # Sets make the membership tests below O(1) instead of O(len(list)).
    train_ids = set(train_list)
    val_ids = set(val_list)

    def _subset(keep_ids):
        """Build a COCO-style dict restricted to images whose id is in keep_ids."""
        return {
            "images": [img for img in data["images"] if img["id"] in keep_ids],
            "annotations": [
                ann for ann in data["annotations"]
                if ann["image_id"] in keep_ids
            ],
            # Category ids are copied unchanged (not re-based to start at 0).
            "categories": list(data["categories"]),
            "type": "instances",
        }

    if args.output:
        os.makedirs(args.output, exist_ok=True)
    with open(os.path.join(args.output, "train.json"), "w", encoding="utf-8") as f:
        json.dump(_subset(train_ids), f, indent=2)
    with open(os.path.join(args.output, "val.json"), "w", encoding="utf-8") as f:
        json.dump(_subset(val_ids), f, indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Start convert.')
    # Folder containing the source .jpg images.
    parser.add_argument('--img_path', type=str, default='G:/TianChi_SmallSample/round1_train_202204/train/images/', help='raw images path')
    # Folder under which the train/ and val/ image folders are created.
    parser.add_argument('--img_path2', type=str, default='G:/TianChi_SmallSample/round1_train_202204/train/val_train/', help='folder under which the train/ and val/ image folders are created')
    # Source annotation JSON (images / annotations / categories).
    parser.add_argument('--json_file', type=str, default='G:/TianChi_SmallSample/round1_train_202204/train/annotations/instances_train2017.json', help='json file path')
    # Folder where train.json and val.json are written.
    parser.add_argument('--output', type=str, help='output folder for train.json and val.json', default='G:/TianChi_SmallSample/round1_train_202204/train/annotations/')
    # Fraction of images assigned to the training split; the rest go to val.
    parser.add_argument('--train_ratio', type=float, default=0.8, help='fraction of images used for training (0-1)')
    args = parser.parse_args()
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error('--train_ratio must be between 0 and 1')
    main(args)

你可能感兴趣的:(MMdetection,python,深度学习)