Pytorch学习系列之数据处理(1)

从这篇开始将记录自己Pytorch的学习,我主要是基于一份Pytorch学习的文档,自己重新手敲一遍里面所有的代码,并且根据自己的基础,加上了一些相应的注释,使得自己可以更加深刻的理解以及记住一些自己不知道的知识。

  • 文档及代码:https://github.com/tensor-yu/PyTorch_Tutorial 可以从这里下载

话不多说开始学习了!
Pytorch学习系列之数据处理(1)_第1张图片

  • 任务:将cifar10的data_batch12345以及test_batch,转换成png格式的图片,每个类别的图片在一个文件夹中,文件夹命名为0-9
  • cifar10数据介绍及下载:https://www.cs.toronto.edu/~kriz/cifar.html
  • 不过建议使用百度云连接下载,会更快一点,文档中也提供了百度云链接
from scipy.misc import imsave  # scipy是基于numpy基础之上的科学计算库 ,从这个库中引入保存图像的函数
import numpy as np  # numpy也是一个科学运算库,支持多维数组和矩阵的运算
import os  # 操作系统接口包
import pickle  # 加工数据用的,可以用来存储结构化数据
from tqdm import tqdm  # tqdm是可以显示循环进度的模块

# 需要处理的数据的目录
data_dir = os.path.join("Data", "cifar-10-batches-py") # 使用os.path.join()函数对路径直接进行拼接,中间的参数是字符串,用逗号隔开
# 训练集和测试集输出的目录
train_dir = os.path.join("Data", "cifar10-png", "train_data")
test_dir = os.path.join("Data", "cifar10-png", "test_data")

Train = True # 不接呀训练集,只解压测试集

# 解压缩函数,返回解压缩后的字典
def Unpickle(file):
    with open(file, 'rb') as fo: # 使用with open()以二进制形式读取文件
        dict_ = pickle.load(fo, encoding='bytes')
    return dict_


# 创建需要的目录
def Create_dir(dir):
    if not os.path.isdir(dir): # 判断如果目录不存在,则创建新的目录
        os.makedirs(dir)


# 主函数,生成数据集的图片
if __name__ ==  "__main__":
    if Train:
        for i in range(1, 6):  # 注意:循环1-5
            train_data_path = os.path.join(data_dir, "data_batch_" + str(i))
            train_data = Unpickle(train_data_path)  # 读取出来的数据是一个字典
            # 对于cifar10数据中读取出来的字典中主要有用的键是b'data', 和b'labels',分别表示的是数据和对应的label

            for j in tqdm(range(0, 10000)):
                train_img = np.reshape(train_data[b'data'][j], (3, 32, 32)).transpose(1, 2, 0)  # 装换成(32, 32, 3)
                train_label_num = str(train_data[b'labels'][j]) # 把label_num转换成str用于下面创建文件

                # 此时已经把图像和标签都提取出来了,现在就是要保存到指定文件夹
                train_o_dir = os.path.join(train_dir, train_label_num)
                Create_dir(train_o_dir) # 创建保存图像的文件夹

                # 设置保存图像的名字
                train_img_name = train_label_num + "_" + str(i + (j - 1) * 100) + ".png"  # 这里str()是为了更好的标识图像,前四位是用来表示第几章图像,后三位表示的是类别
                # train_img_name = train_label_num + "_" + str(i) + ".png"  # 直接这样写也是一样的
                train_img_path = os.path.join(train_o_dir, train_img_name)

                # 保存图片
                imsave(train_img_path, train_img)

    # 测试集数据处理
    test_data_path = os.path.join(data_dir, "test_batch")
    test_data = Unpickle(test_data_path)  # 返回的是字典,因为只有一个test_batch所以不需要for循环

    for k in tqdm(range(0, 10000)):
        test_img = np.reshape(test_data[b'data'][i], (3, 32, 32)).transpose(1, 2, 0) # 装换成(32, 32, 3)
        test_data_labels = str(test_data[b'labels'][i])

        test_o_path = os.path.join(test_dir, test_data_labels) # 对于测试集也是相同的类别放在相同的文件夹下
        Create_dir(test_o_path)

        test_img_name = test_data_labels + "_" + str(i) + ".png"
        test_img_path = os.path.join(test_o_path, test_img_name)
        imsave(test_img_path, test_img)
        # 保存图片到指定文件夹

在跑代码的时候遇到了库的问题,出现了:

from scipy.misc import imsave
ImportError: cannot import name ‘imsave’

解决办法,参考博客:https://blog.csdn.net/qq_32324999/article/details/98986477

你可能感兴趣的:(DL,Pytorch)