深度学习:分类数据集划分python代码

深度学习:分类数据集的划分(python代码)

说明

这段代码用于划分分类数据集:输入目录(--path)下的每个子文件夹对应一个类别,脚本按给定比例(--radtio,默认 0.8)把每个类别的文件随机复制到输出目录(--save)下的 train 和 val 子目录中。

"""
This script splits a classification dataset. For example, given a folder that contains

one sub-folder per class label, it splits the files of each class into a train set and a val set.

"""
import argparse  #参数解析器,可以使用
import os
import random
import shutil
import tqdm

def load_args(argv=None):
    """Build the command-line parser and parse the script's arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``, so the script can also be driven from code
            or tests. ``None`` (the default) reads the real command line.

    Returns:
        argparse.Namespace with the attributes ``path``, ``radtio``,
        ``dataset_type`` and ``save``.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            split ratio outside [0, 1].
    """

    def _ratio(text):
        # Reject out-of-range ratios at parse time; otherwise split() fails
        # later inside random.sample with a much less helpful message.
        value = float(text)
        if not 0.0 <= value <= 1.0:
            raise argparse.ArgumentTypeError(
                'ratio must be between 0 and 1, got {}'.format(text))
        return value

    parser = argparse.ArgumentParser(
        description='Split a folder of per-class sub-folders into train/val sets.')
    parser.add_argument('--path', type=str,
                        default=r'C:\Users\Administrator\Desktop\train',
                        help='source folder containing one sub-folder per class')
    # '--radtio' is the historical (misspelled) option name and is kept for
    # compatibility; '--ratio' is accepted as an alias. Both store into
    # args.radtio, which is the attribute split() reads.
    parser.add_argument('--radtio', '--ratio', dest='radtio', type=_ratio,
                        default=0.8,
                        help='fraction of each class copied into train (0-1)')
    # Layout tag (e.g. train/<class1>, train/<class2>); split() in this
    # script does not read it.
    parser.add_argument('--dataset_type', type=str, default='train_val',
                        help='dataset layout tag (currently unused)')
    parser.add_argument('--save', type=str,
                        default=r'C:\Users\Administrator\Desktop\data',
                        help='output folder; train/ and val/ are created inside it')
    return parser.parse_args(argv)

def _list_files(folder):
    """Return the sorted names of the regular files directly inside *folder*."""
    # Sorting makes the listing deterministic; sub-directories are skipped
    # because shutil.copy cannot copy a directory.
    return sorted(
        name for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )


def _split_names(names, ratio):
    """Randomly partition *names* into (train, val); train gets int(len * ratio) items."""
    train_nums = int(len(names) * ratio)
    # random.sample raises ValueError if train_nums is negative or larger
    # than len(names), i.e. when ratio is outside [0, 1].
    # A set keeps the membership test O(1) instead of O(n) per item.
    train_idx = set(random.sample(range(len(names)), train_nums))
    train = [name for i, name in enumerate(names) if i in train_idx]
    val = [name for i, name in enumerate(names) if i not in train_idx]
    return train, val


def split(args):
    """Randomly split every class folder under ``args.path`` into train/val.

    For each sub-folder (one per class) of ``args.path``, ``int(n * args.radtio)``
    of its n files are copied to ``<args.save>/train/<class>`` and the rest
    to ``<args.save>/val/<class>``. Source files are copied, not moved.
    Top-level entries that are not folders, and sub-folders inside a class
    folder, are skipped.

    Args:
        args: Object with attributes ``path`` (source folder), ``radtio``
            (train fraction, expected in [0, 1]) and ``save`` (output
            folder; created together with any missing parents).

    Raises:
        ValueError: from random.sample when ``args.radtio`` is < 0 or > 1.
    """
    train_root = os.path.join(args.save, 'train')
    val_root = os.path.join(args.save, 'val')

    for class_folder in sorted(os.listdir(args.path)):
        epath = os.path.join(args.path, class_folder)
        # Only sub-folders are classes; a stray file here would make
        # os.listdir(epath) raise NotADirectoryError.
        if not os.path.isdir(epath):
            continue

        # List the folder once so the sampled indices and the names they
        # refer to come from the same listing.
        names = _list_files(epath)
        train_names, val_names = _split_names(names, args.radtio)

        train_save = os.path.join(train_root, class_folder)
        val_save = os.path.join(val_root, class_folder)
        # makedirs also creates args.save and train/ / val/ when missing and
        # does nothing if the folders already exist.
        os.makedirs(train_save, exist_ok=True)
        os.makedirs(val_save, exist_ok=True)

        for name in train_names:
            shutil.copy(os.path.join(epath, name), os.path.join(train_save, name))
        for name in val_names:
            shutil.copy(os.path.join(epath, name), os.path.join(val_save, name))
        print('{}:已完成划分'.format(class_folder))

if __name__ == '__main__':
    # Script entry point: read the command-line options, then split every
    # class folder under --path into train/val copies under --save.
    cli_args = load_args()
    split(cli_args)

你可能感兴趣的:(python,深度学习吧,python,深度学习,分类)