机器学习基础(1)——数据集下载与预处理

1.前言

   为了后续统一数据集进行测试,本系列统一采用开源的花分类数据集,下载地址为

https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

(有可能因为境外服务器的原因下载失败,此处给出flower_photos.tgz百度网盘分享链接:https://pan.baidu.com/s/1OdyKGBQ7wZI_52xAvEI2FQ 
提取码:8idw 
--来自百度网盘超级会员V4的分享)

2.数据预处理

解压下载到的压缩包后,进行数据预处理,按照一定比例(一般为8:2或9:1)拆分数据集

具体代码如下,视不同情况需要自己调整数据集的路径,如果相对路径不行可以尝试绝对路径

import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    if os.path.exists(file_path):
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    random.seed(0)

    # 划分比例
    split_rate = 0.1

    cwd = os.getcwd()
    data_root = os.path.join(cwd, "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # 训练集文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        mk_file(os.path.join(train_root, cla))

    # 验证集文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()

运行后数据集目录如图所示:

机器学习基础(1)——数据集下载与预处理_第1张图片

 

你可能感兴趣的:(人工智能,python)