划分数据集

import os
from shutil import copy, rmtree
import random
'''
划分数据集
'''

def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.3

    #
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "spe_data")
    dataset_root = os.path.join(cwd, "spe_dataset")

    assert os.path.exists(data_root), "path '{}' does not exist.".format(data_root)

    ship_class = [cla for cla in os.listdir(data_root)
                    if os.path.isdir(os.path.join(data_root, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(dataset_root, "train")
    mk_file(train_root)
    for cla in ship_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(dataset_root, "val")
    mk_file(val_root)
    for cla in ship_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in ship_class:
        cla_path = os.path.join(data_root, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()

你可能感兴趣的:(python,人工智能,机器学习)