【PyTorch模型训练实用教程】01 将cifar10图片转化为png+训练集、验证集、测试集划分

目的:阅读作者写的代码,并将不懂的地方记录下来

项目链接:https://github.com/TingsongYu/PyTorch_Tutorial

1_1_cifar10_to_png.py

# coding:utf-8
"""
    将cifar10的data_batch_12345 转换成 png格式的图片
    每个类别单独存放在一个文件夹,文件夹名称为0-9
"""
from imageio import imwrite
import numpy as np
import os
import pickle


data_dir = os.path.join("..", "..", "Data", "cifar-10-batches-py")
train_o_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_train")
test_o_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")

Train = False   # 不解压训练集,仅解压测试集

# 解压缩,返回解压后的字典
def unpickle(file):
    with open(file, 'rb') as fo:
        dict_ = pickle.load(fo, encoding='bytes')
    return dict_

def my_mkdir(my_dir):
    if not os.path.isdir(my_dir):
        os.makedirs(my_dir)


# 生成训练集图片,
if __name__ == '__main__':
    if Train:
        for j in range(1, 6):
            data_path = os.path.join(data_dir, "data_batch_" + str(j))  # data_batch_12345
            train_data = unpickle(data_path)
            print(data_path + " is loading...")

            for i in range(0, 10000):
                img = np.reshape(train_data[b'data'][i], (3, 32, 32))
                img = img.transpose(1, 2, 0)

                label_num = str(train_data[b'labels'][i])
                o_dir = os.path.join(train_o_dir, label_num)
                my_mkdir(o_dir)

                img_name = label_num + '_' + str(i + (j - 1)*10000) + '.png'
                img_path = os.path.join(o_dir, img_name)
                imwrite(img_path, img)
            print(data_path + " loaded.")

    print("test_batch is loading...")

    # 生成测试集图片
    test_data_path = os.path.join(data_dir, "test_batch")
    test_data = unpickle(test_data_path)
    for i in range(0, 10000):
        img = np.reshape(test_data[b'data'][i], (3, 32, 32))
        img = img.transpose(1, 2, 0)

        label_num = str(test_data[b'labels'][i])
        o_dir = os.path.join(test_o_dir, label_num)
        my_mkdir(o_dir)

        img_name = label_num + '_' + str(i) + '.png'
        img_path = os.path.join(o_dir, img_name)
        imwrite(img_path, img)

    print("test_batch loaded.")

第一个不懂的地方:

img = img.transpose(1, 2, 0)将图片的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),关于函数transpose()可见博客:函数解释

reshape的第一个参数只能是通道数

为什么要将图片格式转化为(imagesize,imagesize,channels)呢?

可能是后面用到的函数imwrite()接受的图片格式只能是这样的。

第二个不懂的地方:

这里的imwrite函数不是平常用的opencv中的那个,而是imageio库中的,不知道二者有什么区别

第三个不懂的地方:

构造data_dir,train_o_dir,test_o_dir路径时,为什么要写两个‘..’

代码内容概述:

\Data\cifar-10-batches-py中有6个未解压的文件,5个data_batch_12345中是训练集的图片,共50000张,test_batch中是测试集中的图片共10000张。

【PyTorch模型训练实用教程】01 将cifar10图片转化为png+训练集、验证集、测试集划分_第1张图片

代码通过一个循环将5个data_batch_12345中的图片解压缩,然后对每个data_batch_i中的图片进行处理:

1.通过reshape函数将图片转化为(3,32,32)格式的,因为解压后的图片像素应该是一长串的那种,要通过reshape转化为所需的格式才能进行处理。

2.通过imwrite存储图片。存储图片之前的代码就是在创建每张图片将要存储的地址

处理test_batch中的图片与上述操作相同,处理完的图片存储在文件\Data\cifar-10-png\raw_test和\Data\cifar-10-png\raw_train中。为了方便起见,这里只处理了测试集中的图片,运行后结果如下。数字分别代表不同类别的图片。

【PyTorch模型训练实用教程】01 将cifar10图片转化为png+训练集、验证集、测试集划分_第2张图片

 1_2_split_dataset.py

# coding: utf-8
"""
    将原始数据集进行划分成训练集、验证集和测试集
"""

import os
import glob
import random
import shutil

dataset_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")
train_dir = os.path.join("..", "..", "Data", "train")
valid_dir = os.path.join("..", "..", "Data", "valid")
test_dir = os.path.join("..", "..", "Data", "test")

train_per = 0.8
valid_per = 0.1
test_per = 0.1


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    for root, dirs, files in os.walk(dataset_dir):
        for sDir in dirs:
            imgs_list = glob.glob(os.path.join(root, sDir, '*.png'))
            random.seed(666)
            random.shuffle(imgs_list)
            imgs_num = len(imgs_list)

            train_point = int(imgs_num * train_per)
            valid_point = int(imgs_num * (train_per + valid_per))

            for i in range(imgs_num):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sDir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sDir)
                else:
                    out_dir = os.path.join(test_dir, sDir)

                makedir(out_dir)
                out_path = os.path.join(out_dir, os.path.split(imgs_list[i])[-1])
                shutil.copy(imgs_list[i], out_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sDir, train_point, valid_point-train_point, imgs_num-valid_point))

第一个不懂的地方:

os.walk()函数的作用:参考

【PyTorch模型训练实用教程】01 将cifar10图片转化为png+训练集、验证集、测试集划分_第3张图片

 因此代码中dirs返回的是0123456789共10个子目录的名字

第二个不懂的地方:

glob.glob()函数的作用:参考

可以将某目录下所有跟通配符模式相同的文件放到一个列表中,有了这个函数,我们再想生成所有文件的列表就不需要使用for循环遍历目录了

因此 imgs_list中返回的是所有的图片路径

第三个不懂的地方:

shutil.copy(src, dst)

复制文件内容(不包含元数据)从src到dst。

代码内容概述:

将0、1、2、3、4、5、6、7、8、9中的图片分别划分为训练集、验证集、测试集。最后得到train、valid、test三个文件夹,每个文件夹下又按图片类别分为10类。

你可能感兴趣的:(PyTorch模型训练实用教程,pytorch,深度学习,python)