划分训练集、验证集和测试集代码


输入需要划分的数据所在的文件夹,返回一个划分好的包含train,val和test的文件夹。

val_ratio,和test_ratio分别为验证集和测试集所占的比例,test_ratio=0表示不划分测试集。

import os
import cv2
import random
import shutil
import numpy as np
from tqdm import tqdm
from collections import defaultdict
'''
源文件夹下的图片和标签文件需要命名为images和labels
记得修改image的扩展名:jpg  or  png
'''



def main(src_dir, dst_dir, val_ratio=0.1, test_ratio=0):
    src_imgs_dir = os.path.join(src_dir, 'images')
    src_labels_dir = os.path.join(src_dir, 'labels')
    fids = [f for f in os.listdir(src_imgs_dir) if f.endswith('.png')]
    random.seed(6)
    random.shuffle(fids)
    total_num = len(fids)
    val_num = int(total_num * val_ratio)
    test_num = int(total_num * test_ratio)
    train_num = total_num - val_num - test_num

    train_set = fids[:train_num]
    val_set = fids[train_num: train_num + val_num]
    test_set = fids[train_num + val_num:]

    print('train num: {}  val num: {} test num: {}'.format(len(train_set), len(val_set), len(test_set)))

    out_train_dir = os.path.join(dst_dir, 'train')
    out_val_dir = os.path.join(dst_dir, 'val')

    os.makedirs(os.path.join(out_train_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(out_train_dir, 'labels'), exist_ok=True)
    os.makedirs(os.path.join(out_val_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(out_val_dir, 'labels'), exist_ok=True)


    for fid in tqdm(train_set):
        shutil.copy(os.path.join(src_imgs_dir, fid), os.path.join(out_train_dir, 'images', fid))
        shutil.copy(os.path.join(src_labels_dir, fid), os.path.join(out_train_dir, 'labels', fid))
    for fid in tqdm(val_set):
        shutil.copy(os.path.join(src_imgs_dir, fid), os.path.join(out_val_dir, 'images', fid))
        shutil.copy(os.path.join(src_labels_dir, fid), os.path.join(out_val_dir, 'labels', fid))
    if test_ratio:
        out_test_dir = os.path.join(dst_dir, 'test')
        os.makedirs(os.path.join(out_test_dir, 'images'), exist_ok=True)
        os.makedirs(os.path.join(out_test_dir, 'labels'), exist_ok=True)
        for fid in tqdm(test_set):
            shutil.copy(os.path.join(src_imgs_dir, fid), os.path.join(out_test_dir, 'images', fid))
            shutil.copy(os.path.join(src_labels_dir, fid), os.path.join(out_test_dir, 'labels', fid))
    print('done!')


if __name__ == '__main__':

    src_dir = r'E:\New_gj\dataset\T_multi_22\cut'             #源文件
    dst_dir = r'E:\New_gj\dataset\T_multi_22\split'            #目标目录
    main(src_dir, dst_dir, val_ratio=0.1, test_ratio=0)

你可能感兴趣的:(数据集,python,机器学习,深度学习)