Pytorch框架训练时的数据预处理、数据集以及导入、加载数据

前言

目前刚刚接触深度学习方向,也在学习pytorch框架。本文是我在尝试相关网络的pytorch框架时遇到的一些问题以及认为有必要总结一下的内容。

此内容主要参考了以下博客:https://blog.csdn.net/m0_37867091/article/details/107150142​​​​​​

数据预处理

在网络开始训练之前,为了使训练更好的进行,我们需要对训练进行一些预处理操作。在pytorch中是由torchvision.transforms来操作的,torchvision.transforms中包含了一些常见的操作。以下是目前见到常用的几种:

transforms.Compose可以用来将多种操作集合到一起,打包了多个图片处理的方法,如:

transforms.Compose([ 
transforms.CenterCrop(10), 
transforms.ToTensor(), 
]) 

transforms.ToTensor() 将shape(H, W, C)nump.ndarrayimg转为shape(C, H, W)tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255即可。

transforms.Normalize()其作用就是先将输入归一化到(0,1),再使用公式"(x-mean)/std",将每个元素分布到(-1,1)。

torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的。其构成如下:

torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
torchvision.utils: 其他的一些有用的方法。


原文链接:https://blog.csdn.net/wangkaidehao/article/details/104520022/

数据集

各种网络模型的训练都离不开数据集的支持,当我们针对某个数据集时,往往是两种导入方法:1.pytorch内置的torchvision.datasets函数进行在线导入相关的数据集

Pytorch框架训练时的数据预处理、数据集以及导入、加载数据_第1张图片

2.导入个人制作的数据集

参考:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/data_set/README.md

个人的数据集需要划分为训练集、测试集两部分,下面是对数据集进行分类的脚本:

import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 指向你解压后的flower_photos文件夹
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()

其中文件夹的名称根据自己的数据集进行替换。

导入、加载数据

对于在torchvision图形库中在线导入的数据集代码如下:

# 导入训练集
train_set = torchvision.datasets.CIFAR10(root='./data',      # 数据集存放目录
                                         train=True,         # 表示是数据集中的训练集
                                        download=True,       # 第一次运行时为True,下载数据集,下载完成后改为False
                                        transform=transform) # 预处理过程
# 加载训练集    
train_loader = torch.utils.data.DataLoader(train_set,       # 导入的训练集
                                           batch_size=50, # 每批训练的样本数
                                          shuffle=False,  # 是否打乱训练集
                                          num_workers=0)  # num_workers在windows下设置为0

对于个人划分的数据集代码如下:

# 获取图像数据集的路径
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))          # get data root path 
image_path = data_root + "/data_set/flower_data/"                           # flower data_set path

# 导入训练集并进行预处理
train_dataset = datasets.ImageFolder(root=image_path + "/train",        
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# 按batch_size分批次加载训练集
train_loader = torch.utils.data.DataLoader(train_dataset,    # 导入的训练集
                                           batch_size=32,     # 每批训练的样本数
                                           shuffle=True,    # 是否打乱训练集
                                           num_workers=0)    # 使用线程数,在windows下设置为0
 

你可能感兴趣的:(pytorch,深度学习,机器学习)