PyTorch 模型训练实用教程(一):数据

目录

Cifar10 转 png

第一步:下载 cifar-10-python.tar.gz

第二步:运行 1_1_cifar10_to_png.py

主要模块

scipy.misc.imsave()函数:

pickle模块

 os.path.join

第三步: 训练集、验证集和测试集的划分

主要模块

glob模块

 shutil模块

os.walk()方法

split()和os.path.split()

让 PyTorch 能读数据集

Dataset 类

1. 制作图片数据的索引

2. 构建 Dataset 子类

图片从硬盘到模型


Cifar10 png

为了统一数据,这里采用 cifar-10 的测试集,共 10000 张图片 作为源数据,模拟真实场景中的数据。

这里主要介绍一下这个数据集的目录结构以及内部数据组织格式.

第一步:下载 cifar-10-python.tar.gz

下载 cifar-10-python.tar.gz ,存放到 /Data 文件夹下,并且解压,获得文件夹 /Data/cifar-10-batches-py/
下载方式:
1. 官网: http://www.cs.toronto.edu/~kriz/cifar.html

第二步:运行 1_1_cifar10_to_png.py

1_1_cifar10_to_png.py代码如下:

# coding:utf-8
"""
    将cifar10的data_batch_12345 转换成 png格式的图片
    每个类别单独存放在一个文件夹,文件夹名称为0-9
"""
from scipy.misc import imsave
import numpy as np
import os
import pickle


data_dir = os.path.join("..", "..", "Data", "cifar-10-batches-py")
train_o_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_train")
test_o_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")

Train = False   # 不解压训练集,仅解压测试集

# 解压缩,返回解压后的字典
def unpickle(file):
    with open(file, 'rb') as fo:
        dict_ = pickle.load(fo, encoding='bytes')
    return dict_

def my_mkdir(my_dir):
    if not os.path.isdir(my_dir):
        os.makedirs(my_dir)


# 生成训练集图片,
if __name__ == '__main__':
    if Train:
        for j in range(1, 6):
            data_path = os.path.join(data_dir, "data_batch_" + str(j))  # data_batch_12345
            train_data = unpickle(data_path)
            print(data_path + " is loading...")

            for i in range(0, 10000):
                img = np.reshape(train_data[b'data'][i], (3, 32, 32))
                img = img.transpose(1, 2, 0)

                label_num = str(train_data[b'labels'][i])
                o_dir = os.path.join(train_o_dir, label_num)
                my_mkdir(o_dir)

                img_name = label_num + '_' + str(i + (j - 1)*10000) + '.png'
                img_path = os.path.join(o_dir, img_name)
                imsave(img_path, img)
            print(data_path + " loaded.")

    print("test_batch is loading...")

    # 生成测试集图片
    test_data_path = os.path.join(data_dir, "test_batch")
    test_data = unpickle(test_data_path)
    for i in range(0, 10000):
        img = np.reshape(test_data[b'data'][i], (3, 32, 32))
        img = img.transpose(1, 2, 0)

        label_num = str(test_data[b'labels'][i])
        o_dir = os.path.join(test_o_dir, label_num)
        my_mkdir(o_dir)

        img_name = label_num + '_' + str(i) + '.png'
        img_path = os.path.join(o_dir, img_name)
        imsave(img_path, img)

    print("test_batch loaded.")

主要模块

scipy.misc.imsave()函数:

这个函数用于储存图片,将数组保存为图像。此功能仅在安装了Python Imaging Library(PIL)时可用。新的替代它的是imageio.imwrite()
用法:imsave(name, arr, format=None)
参数:

  • name : 文件名或者文件名加目录
  • arr:np-array的矩阵,MxN or MxNx3 or MxNx4这三种格式,分别对应灰度图像,RGB图像和RGB+alpha图像
  • format :str型,图像输出的类型,省略的话,图片直接输出图片的扩展名。
#灰度图像
from scipy.misc import imsave
x = np.zeros((255, 255))
x = np.zeros((255, 255), dtype=np.uint8)
x[:] = np.arange(255)
imsave('gradient.png', x)

#RGB图像
rgb = np.zeros((255, 255, 3), dtype=np.uint8)
rgb[..., 0] = np.arange(255)
rgb[..., 1] = 55
rgb[..., 2] = 1 - np.arange(255)
imsave('rgb_gradient.png', rgb)

值得注意的是,这个函数默认的情况下,会检测你输入的RGB值的范围,如果都在0到1之间的话,那么会自动扩大范围至0到255。也就是说,这个时候你乘不乘255输出图片的效果一样的。

pickle模块

本次主要使用pickle模块将下载的序列化图像信息读出来,转变为bytes编码。

1

2

3

4

5

def unpickle(file):

    import pickle

    with open(file, 'rb') as fo:

        dict = pickle.load(fo, encoding='bytes')

    return dict

  Python提供的pickle模块可以序列化对象并保存到磁盘中,并在需要的时候读取出来,任何对象都可以执行序列化操作。

1、pickle.dump(object, file, protocol=)  将object对象序列化到打开的文件夹file中。protocol是序列化协议,默认是0,如果是负数或者HIGHEST_PROTOCOL,则使用最高版本序列化协议。

2、pickle.load(file,encoding)  把file中的对象读出,encoding 参数可置为 'bytes' 来将这些 8 位字符串实例读取为字节对象。

3、pickle.dumps(obj,protocol=None) 将 obj 打包以后的对象作为 'bytes'类型直接返回,而不是将其写入到文件。

4、pickle.loads(bytes_object)  对于打包生成的对象 bytes_object,还原出原对象的结构并返回。

dump() 和 load() 与 dumps() 和 loads()的区别 dump()函数能一个接着一个地将几个对象序列化存储到同一个文件中,随后调用load()来以同样的顺序反序列化读出这些对象。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

import pickle

import os

listdata = [[1,2,3],

       [2,3,4],

       5,6,7]]

 

# f = "{}/{}".format(r"./practice_data",r"pra.txt")

#

# if not os.path.exists(f):

#     os.md(f)

#将数据序列化

fw = open(r"practice_data/pra.txt",'wb')

pickle.dump(listdata,fw,-1)

fw.close()

#将序列化数据读出

fr = open(r"practice_data/pra.txt",'rb')

fd = pickle.load(fr)

print(fd)

fr.close()

 

#使用dumps和loads举例

a = pickle.dumps(listdata)

print(pickle.loads(a))

 输出:

1

2

[[1, 2, 3], [2, 3, 4], [5, 6, 7]]

[[1, 2, 3], [2, 3, 4], [5, 6, 7]]

 os.path.join

python中有join()和os.path.join()两个函数,具体作用如下:

  • join(): 连接字符串数组。将字符串、元组、列表中的元素以指定的字符(分隔符)连接生成一个新的字符串

‘sep’.join(seq)
参数:
sep:分隔符。可以为空
seq:要连接的元素序列、字符串、元组、字典
上面的语法即:以sep作为分隔符,将seq所有的元素合并成一个新的字符串
返回值:返回一个以分隔符sep连接各个元素后生成的字符串

  • os.path.join(): 将多个路径组合后返回

os.path.join()函数
语法: os.path.join(path1[,path2[,……]])
返回值:将多个路径组合后返回

os.path.join()函数:连接两个或更多的路径名组件

  • 如果各组件名首字母不包含’/’,则函数会自动加上
  • 如果有一个组件是一个绝对路径,则在它之前的所有组件均会被舍弃
  • 如果最后一个组件为空,则生成的路径以一个’/’分隔符结尾
  • 注意:若出现”./”开头的参数,会从”./”开头的参数的上一个参数开始拼接

运行完毕后,可在文件夹 Data/cifar-10-png/raw_test/下看到 0-9 个文件夹,对应 9 个类别。 脚本中未将训练集解压出来,这里只是为了实验,因此未使用过多的数据。这里仅将测试集中的 10000 张图片解压出来,作为原始图片,将从这 10000 张图片中划分出训练集(train),验证集(valid),测试集(test)


第三步: 训练集、验证集和测试集的划分

上一小节,把 cifar-10 的测试集转换成了 png 图片,充当实验的原始数据。本小节,将把原始数据按 8:1:1 的比例划分为训练集 (train set) 、验证集 (valid/dev set) 和测试集 (test set) 。关于训练集、验证集和测试集的作用,可阅读博客: https://blog.csdn.net/u011995719/article/details/77451213
运行 Code/1_data_prepare/1_2_split_dataset.py ,将会获得以下三个文件夹:
  • /Data/train/
  • /Data/valid/
  • /Data/test/
"""
    将原始数据集进行划分成训练集、验证集和测试集
"""

import os
import glob
import random
import shutil

dataset_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")
train_dir = os.path.join("..", "..", "Data", "train")
valid_dir = os.path.join("..", "..", "Data", "valid")
test_dir = os.path.join("..", "..", "Data", "test")

train_per = 0.8
valid_per = 0.1
test_per = 0.1


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    for root, dirs, files in os.walk(dataset_dir):
        for sDir in dirs:
            imgs_list = glob.glob(os.path.join(root, sDir, '*.png'))
            random.seed(666)
            random.shuffle(imgs_list)
            imgs_num = len(imgs_list)

            train_point = int(imgs_num * train_per)
            valid_point = int(imgs_num * (train_per + valid_per))

            for i in range(imgs_num):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sDir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sDir)
                else:
                    out_dir = os.path.join(test_dir, sDir)

                makedir(out_dir)
                #out_path = os.path.join(out_dir, os.path.split(imgs_list[i])[-1])
                #shutil.copy(imgs_list[i], out_path)
                shutil.copy(imgs_list[i], out_dir)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sDir, train_point, valid_point-train_point, imgs_num-valid_point))

主要模块

glob模块

glob模块用来一次性读取对应文件夹下所有符合要求的子文件夹和子文件夹下的文件列表,常见的两个方法有glob.glob()和glob.iglob(),可以和常用的find功能进行类比,glob支持*?[]这三种通配符

  • *代表0个或多个字符
  • ?代表一个字符
  • [ ]匹配指定范围内的字符,如[0-9]匹配数字

 shutil模块

  • copy()
功能:复制文件
格式:shutil.copy('来源文件','目标地址')

其他用法:python中的shutil模块

os.walk()方法

 os.walk的函数声明为:
walk(top, topdown=True, οnerrοr=None, followlinks=False)
参数

  • top 是遍历的目录地址
  • topdown 为真,则优先遍历top目录,否则优先遍历top的子目录(默认为开启)
  • onerror 需要一个 callable 对象,当walk需要异常时,会调用
  • followlinks 如果为真,则会遍历目录下的快捷方式(linux 下是 symbolic link)实际所指的目录(默认关闭)

os.walk 的返回值是一个生成器(generator),也就是说我们需要不断的遍历它,来获得所有的内容。
每次遍历的对象都是返回的是一个三元组(root,dirs,files)

返回参数说明:

  1. root 所指的是当前正在遍历的这个文件夹的本身的地址
  2. dirs 是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
  3. files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录) files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)

split()和os.path.split()

  • split():拆分字符串。通过指定分隔符对字符串进行切片,并返回分割后的字符串列表(list)

语法:str.split(str="",num=string.count(str))[n]

参数说明:

  • str:表示为分隔符,默认为空格,但是不能为空(’ ’)。若字符串中没有分隔符,则把整个字符串作为列表的一个元素
  • num:表示分割次数。如果存在参数num,则仅分隔成 num+1 个子字符串,并且每一个子字符串可以赋给新的变量
  • [n]:表示选取第n个分片

注意:当使用空格作为分隔符时,对于中间为空的项会自动忽略

  • os.path.split():按照路径将文件名和路径分割开

语法:os.path.split(‘PATH’)

参数说明:

  • PATH指一个文件的全路径作为参数:
  • 如果给出的是一个目录和文件名,则输出路径和文件名
  • 如果给出的是一个目录名,则输出路径和为空文件名
数据划分完毕,下一步是制作存放有图片路径及其标签的 txt PyTorch 依据该 txt 的信息进行寻找图片,并读取图片数据和标签数据

PyTorch 能读数据集

Dataset

PyTorch 读取图片,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。 Dataset 类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它,类似于 C++ 中的虚基 类.
源码:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
这里重点看 getitem 函数, getitem 接收一个 index ,然后返回图片数据和标签,这个 index 通常指的是一个 list index ,这个 list 的每个元素就包含了图片数据的路径和标签信 息。
然而,如何制作这个 list 呢,通常的方法是将图片的路径和标签信息存储在一个 txt中,然后从该 txt 中读取。
那么读取自己数据的基本流程就是:
1. 制作存储了图片的路径和标签信息的 txt
2. 将这些信息转化为 list ,该 list 每一个元素对应一个样本
3. 通过 getitem 函数,读取数据和标签,并返回数据和标签
在训练代码里是感觉不到这些操作的,只会看到通过 DataLoader 就可以获取一个 batch 的数据,其实触发去读取图片这些操作的是 DataLoader 里的 __iter__(self) ,后面会详 细讲解读取过程。在本小节,主要讲 Dataset 子类。
因此,要让 PyTorch 能读取自己的数据集,只需要两步:
1. 制作图片数据的索引
2. 构建 Dataset 子类

1. 制作图片数据的索引

这个比较简单,就是读取图片路径,标签,保存到 txt 文件中,这里注意格式就好
特别注意的是, txt 中的路径,是以训练时的那个 py 文件所在的目录为工作目录,所以这 里需要提前算好相对路径!
运行代码:
# coding:utf-8
import os

'''
    为数据集生成对应的txt文件
'''

train_txt_path = os.path.join("..", "..", "Data", "train.txt")
train_dir = os.path.join("..", "..", "Data", "train")

valid_txt_path = os.path.join("..", "..", "Data", "valid.txt")
valid_dir = os.path.join("..", "..", "Data", "valid")


def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')
    
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径
            img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):         # 若不是png文件,跳过
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()


if __name__ == '__main__':
    gen_txt(train_txt_path, train_dir)
    gen_txt(valid_txt_path, valid_dir)

即会在 /Data/ 文件夹下面看到train.txt valid.txt

2. 构建 Dataset 子类

下面是本实验构建的 Dataset 子类 ——MyDataset 类:
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
    def __init__(self, txt_path, transform = None, target_transform = None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label
    def __len__(self):
        return len(self.imgs)
首先看看初始化,初始化中从我们准备好的 txt 里获取图片的路径和标签,并且存储 self.imgs self.imgs 就是上面提到的 list ,其一个元素对应一个样本的路径和标签,其实 就是 txt 中的一行。
初始化中还会初始化 transform transform 是一个 Compose 类型,里边有一个 list list中就会定义了各种对图像进行处理的操作,可以设置减均值,除标准差,随机裁剪,旋 转,翻转,仿射变换等操作。
在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),最终变成输入模型的数据。这里就有一点需要注意, PyTorch 的数据增强是将原始图片进行了 处理,并不会生成新的一份图片,而是 覆盖 原图,当采用 randomcrop 之类的随机操作 时,每个 epoch 输入进来的图片几乎不会是一模一样的,这达到了样本多样性的功能。
然后看看核心的 getitem 函数:
第一行: self.imgs 是一个 list ,也就是一开始提到的 list self.imgs 的一个元素是一个 str 包含图片路径,图片标签,这些信息是从 txt 文件中读取
第二行:利用 Image.open 对图片进行读取, img 类型为 Image mode=‘RGB’
第三行与第四行: 对图片进行处理,这个 transform 里边可以实现 减均值,除标准差,随机裁剪,旋转,翻转,放射变换,等等操作,这个放在后面会详细讲解。
Mydataset 构建好,剩下的操作就交给 DataLoder ,在 DataLoder 中,会触发Mydataset 中的 getiterm 函数读取一张图片的数据和标签,并拼接成一个 batch 返回,作为 模型真正的输入。下一小节将会通过一个小例子,介绍 DataLoder 是如何获取一个 batch 以及一张图片是如何被 PyTorch 读取,最终变为模型的输入的。

图片从硬盘到模型

上小节中介绍了如何构建自己的 Dataset 子类 ——MyDataset ,在 MyDataset 中,主要获取图片的索引以及定义如何通过索引读取图片及其标签。但是要触发 MyDataset 去读取图片及其标签却是在数据加载器 DataLoder 中。本小节,将进行单步调试,学习图片是如 何从硬盘上流到模型的输入口的,并观察图片经历了哪些处理。
对应代码:
/Code/main_training/main.py
大体流程 :
1. main.py: train_data = MyDataset(txt_path=train_txt_path, ... --->
2. main.py: train_loader = DataLoader(dataset=train_data, ...) --->
3. main.py: for i, data in enumerate(train_loader, 0) --->
4. dataloder.py: class DataLoader(): def __iter__(self): return _DataLoaderIter(self) --->
5. dataloder.py: class _DataLoderIter(): def __next__(self): batch = self.collate_fn([self.dataset[i]for i in indices]) --->
6. tool.py: class MyDataset(): def __getitem__(): img = Image.open(fn).convert('RGB') --->
7. tool.py: class MyDataset(): img = self.transform(img) --->
8. main.py: inputs, labels = data inputs, labels = Variable(inputs), Variable(labels) outputs =45net(inputs)
一句话概括就是,从 MyDataset 来,到 MyDataset 去。
一开始通过 MyDataset 创建一个实例,在该实例中有路径,有读取图片的方法 ( 函数 )
然后需要 pytroch 的一系列规范化流程,在第 6 步中,才会调用 MyDataset 中的__getitem__() 函数,最终通过 Image.open() 读取图片数据。
然后对原始图片数据进行一系列预处理 (transform 中设置 ) ,最后回到 main.py ,对数据进行转换成 Variable 类型,最终成为模型的输入。
流程详细描述:
1. MyDataset 类中初始化 txt txt 中有图片路径和标签
2. 初始化 DataLoder 时,将 train_data 传入,从而使 DataLoder 拥有图片的路径
3. 在一个 iteration 进行时,才读取一个 batch 的图片数据。 enumerate() 函数会返回可迭代数据的一个 元素 ” 。在这里 data 是一个 batch 的图片数据和标签, data 是一个 list
4. class DataLoader() 中再调用 class _DataLoderIter()
5. _DataLoderiter() 类中会跳到 __next__(self) 函数,在该函数中会通过indices = next(self.sample_iter) 获取一个 batch indices
再通过batch = self.collate_fn([self.dataset[i] for i in indices]) 获取一个 batch 的数据
batch = self.collate_fn([self.dataset[i] for i in indices]) 中会调用 self.collate_fn 函数
6. self.collate_fn 中会调用 MyDataset 类中的 __getitem__() 函数,在 __getitem__() 中通过Image.open(fn).convert('RGB') 读取图片
7. 通过 Image.open(fn).convert('RGB') 读取图片之后,会对图片进行预处理,例如减均值,除以标准差,随机裁剪等等一系列提前设置好的操作。
具体 transform 的用法将用单独一小节介绍,最后返回 img label ,再通过 self.collate_fn 来拼接成一个 batch 。一个 batch 是一个 list ,有两个元素,第一个元素是图片数据,是一个 4D Tensor shape (64,3,32,32) ,第二个元素是标签 shape (64)
8. 将图片数据转换成 Variable 类型,然后称为模型真正的输入
inputs, labels = Variable(inputs), Variable(labels)
outputs = net(inputs)
通过了解图片从硬盘到模型的过程,我们可以更好的对数据做处理 ( 减均值,除以标准差,裁剪,翻转,放射变换等等 ) ,也可以灵活的为模型准备数据,最后总结两个需要注意 的地方。
1. 图片是通过 Image.open() 函数读取进来的,当涉及如下问题:
图片的通道顺序 (RGB ? BGR ?)
图片是 w*h*c c*w*h
像素值范围 [0-1] or [0-255]
就要查看 MyDataset() 类中 __getitem__( )下读取图片用的是什么方法
2. MyDataset() 类中 __getitem__( )函数中发现, PyTorch 做数据增强的方法是在原始图片上进行的,并覆盖原始图片,这一点需要注意。

你可能感兴趣的:(pytorch学习笔记)