目的:阅读作者写的代码,并将不懂的地方记录下来
项目链接:https://github.com/TingsongYu/PyTorch_Tutorial
1_1_cifar10_to_png.py
# coding:utf-8
"""
将cifar10的data_batch_12345 转换成 png格式的图片
每个类别单独存放在一个文件夹,文件夹名称为0-9
"""
from imageio import imwrite
import numpy as np
import os
import pickle
data_dir = os.path.join("..", "..", "Data", "cifar-10-batches-py")
train_o_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_train")
test_o_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")
Train = False # 不解压训练集,仅解压测试集
# 解压缩,返回解压后的字典
def unpickle(file):
with open(file, 'rb') as fo:
dict_ = pickle.load(fo, encoding='bytes')
return dict_
def my_mkdir(my_dir):
if not os.path.isdir(my_dir):
os.makedirs(my_dir)
# 生成训练集图片,
if __name__ == '__main__':
if Train:
for j in range(1, 6):
data_path = os.path.join(data_dir, "data_batch_" + str(j)) # data_batch_12345
train_data = unpickle(data_path)
print(data_path + " is loading...")
for i in range(0, 10000):
img = np.reshape(train_data[b'data'][i], (3, 32, 32))
img = img.transpose(1, 2, 0)
label_num = str(train_data[b'labels'][i])
o_dir = os.path.join(train_o_dir, label_num)
my_mkdir(o_dir)
img_name = label_num + '_' + str(i + (j - 1)*10000) + '.png'
img_path = os.path.join(o_dir, img_name)
imwrite(img_path, img)
print(data_path + " loaded.")
print("test_batch is loading...")
# 生成测试集图片
test_data_path = os.path.join(data_dir, "test_batch")
test_data = unpickle(test_data_path)
for i in range(0, 10000):
img = np.reshape(test_data[b'data'][i], (3, 32, 32))
img = img.transpose(1, 2, 0)
label_num = str(test_data[b'labels'][i])
o_dir = os.path.join(test_o_dir, label_num)
my_mkdir(o_dir)
img_name = label_num + '_' + str(i) + '.png'
img_path = os.path.join(o_dir, img_name)
imwrite(img_path, img)
print("test_batch loaded.")
img = img.transpose(1, 2, 0)将图片的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),关于函数transpose()可见博客:函数解释
reshape的第一个参数只能是通道数
为什么要将图片格式转化为(imagesize,imagesize,channels)呢?
可能是后面用到的函数imwrite()接受的图片格式只能是这样的。
这里的imwrite函数不是平常用的opencv中的那个,而是imageio库中的,不知道二者有什么区别
构造data_dir,train_o_dir,test_o_dir路径时,为什么要写两个‘..’
\Data\cifar-10-batches-py中有6个未解压的文件,5个data_batch_12345中是训练集的图片,共50000张,test_batch中是测试集中的图片共10000张。
代码通过一个循环将5个data_batch_12345中的图片解压缩,然后对每个data_batch_i中的图片进行处理:
1.通过reshape函数将图片转化为(3,32,32)格式的,因为解压后的图片像素应该是一长串的那种,要通过reshape转化为所需的格式才能进行处理。
2.通过imwrite存储图片。存储图片之前的代码就是在创建每张图片将要存储的地址
处理test_batch中的图片与上述操作相同,处理完的图片存储在文件\Data\cifar-10-png\raw_test和\Data\cifar-10-png\raw_train中。为了方便起见,这里只处理了测试集中的图片,运行后结果如下。数字分别代表不同类别的图片。
1_2_split_dataset.py
# coding: utf-8
"""
将原始数据集进行划分成训练集、验证集和测试集
"""
import os
import glob
import random
import shutil
dataset_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")
train_dir = os.path.join("..", "..", "Data", "train")
valid_dir = os.path.join("..", "..", "Data", "valid")
test_dir = os.path.join("..", "..", "Data", "test")
train_per = 0.8
valid_per = 0.1
test_per = 0.1
def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(new_dir)
if __name__ == '__main__':
for root, dirs, files in os.walk(dataset_dir):
for sDir in dirs:
imgs_list = glob.glob(os.path.join(root, sDir, '*.png'))
random.seed(666)
random.shuffle(imgs_list)
imgs_num = len(imgs_list)
train_point = int(imgs_num * train_per)
valid_point = int(imgs_num * (train_per + valid_per))
for i in range(imgs_num):
if i < train_point:
out_dir = os.path.join(train_dir, sDir)
elif i < valid_point:
out_dir = os.path.join(valid_dir, sDir)
else:
out_dir = os.path.join(test_dir, sDir)
makedir(out_dir)
out_path = os.path.join(out_dir, os.path.split(imgs_list[i])[-1])
shutil.copy(imgs_list[i], out_path)
print('Class:{}, train:{}, valid:{}, test:{}'.format(sDir, train_point, valid_point-train_point, imgs_num-valid_point))
os.walk()函数的作用:参考
因此代码中dirs返回的是0123456789共10个子目录的名字
glob.glob()函数的作用:参考
可以将某目录下所有跟通配符模式相同的文件放到一个列表中,有了这个函数,我们再想生成所有文件的列表就不需要使用for循环遍历目录了
因此 imgs_list中返回的是所有的图片路径
shutil.copy(src, dst)
复制文件内容(不包含元数据)从src到dst。
将0、1、2、3、4、5、6、7、8、9中的图片分别划分为训练集、验证集、测试集。最后得到train、valid、test三个文件夹,每个文件夹下又按图片类别分为10类。