Pytorch模型训练实用教程-代码解读(2)

# coding: utf-8
"""
    将原始数据集进行划分成训练集、验证集和测试集
"""

import os
import glob
import random
import shutil

dataset_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")
train_dir = os.path.join("..", "..", "Data", "train")
valid_dir = os.path.join("..", "..", "Data", "valid")
test_dir = os.path.join("..", "..", "Data", "test")

train_per = 0.8
valid_per = 0.1
test_per = 0.1


def makedir(new_dir):  # 目录是否存在,如果不存在则创建
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    for root, dirs, files in os.walk(dataset_dir):  # dirs=0-9
        for sDir in dirs:
            imgs_list = glob.glob(os.path.join(root, sDir, '*.png'))  # 第0个文件夹里面的照片列表 第2个。。。。
            random.seed(666)
            random.shuffle(imgs_list)
            imgs_num = len(imgs_list)  # 每个0个文件夹中的照片个数

            train_point = int(imgs_num * train_per)
            valid_point = int(imgs_num * (train_per + valid_per))

            for i in range(imgs_num):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sDir)  #  train_dir/0
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sDir)
                else:
                    out_dir = os.path.join(test_dir, sDir)

                makedir(out_dir)  # 生成test,train,valid,和里面的0-9 文件夹
                out_path = os.path.join(out_dir, os.path.split(imgs_list[i])[-1])  # 从imgs_list里面一个一个取,要png的名字
                # train_dir/0/1.png
                shutil.copy(imgs_list[i], out_path)
                # raw_test/0/1.png 复制到 train_dir/0/1.png  
            print('Class:{}, train:{}, valid:{}, test:{}'.format(sDir, train_point, valid_point-train_point, imgs_num-valid_point))

问题一:

Pytorch模型训练实用教程-代码解读(2)_第1张图片

问题二Pytorch模型训练实用教程-代码解读(2)_第2张图片

Pytorch模型训练实用教程-代码解读(2)_第3张图片

Pytorch模型训练实用教程-代码解读(2)_第4张图片

Pytorch模型训练实用教程-代码解读(2)_第5张图片问题三

 random.shuffle()用于将一个列表中的元素打乱顺序,值得注意的是使用这个方法不会生成新的列表,只是将原列表的次序打乱

问题四

os.path.split()函数
将文件名和路径分割开。
语法:os.path.split(‘PATH’)
参数说明:

PATH指一个文件的全路径作为参数:
如果给出的是一个目录和文件名,则输出路径和文件名
如果给出的是一个目录名,则输出路径和为空文件名

实际上,该函数的分割并不智能,它仅仅是以 “PATH” 中最后一个 ‘/’ 作为分隔符,
分隔后,将索引为0的视为目录(路径),将索引为1的视为文件名,如:

>>> import os
>>> os.path.split('C:/soft/python/test.py')
('C:/soft/python', 'test.py')
>>> os.path.split('C:/soft/python/test')
('C:/soft/python', 'test')
>>> os.path.split('C:/soft/python/')
('C:/soft/python', '')

问题五

Pytorch模型训练实用教程-代码解读(2)_第6张图片

你可能感兴趣的:(pytorch模型训练实用教程,python,机器学习,开发语言)