【pytorch学习笔记四】数据的加载(Dataset,Dataloader)和预处理(torchvision 详细介绍)

【pytorch学习笔记四】数据的加载(Dataset,Dataloader)和预处理(torchvision 详细介绍)_第1张图片

文章目录

    • 1.自定义数据集Dataset
    • 2.读取数据集Dataloader
    • 3.torchvision 包
      • 3.1torchvision.datasets
      • 3.2torchvision.models
      • 3.3torchvision.transforms
      • 3.4 常见的torchvision.transforms图片操作

PyTorch通过torch.utils.data对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。 并且torchvision已经预先实现了常用图像数据集,包括前面使用过的CIFAR-10,ImageNet、COCO、MNIST、LSUN等数据集,可通过torchvision.datasets方便的调用。

1.自定义数据集Dataset

为了能够方便的读取,需要将要使用的数据包装为Dataset类。 自定义的Dataset需要继承它并且实现两个成员方法: 1. __getitem__() 该方法定义用索引(0len(self))获取一条数据或一个样本 2. __len__() 该方法返回数据集的总长度

import torch
from torch.utils.data import Dataset
import pandas as pd
#自定义一个数据集类,继承Dataset
class BluebookDataset(Dataset):
    '''数据集演示'''
    def __init__(self,csv_file):
        '''初始化时将数据载入'''
        self.df=pd.read_csv(csv_file)
    def __len__(self):
        return len(self.df)#获取长度
    def __getitem__(self,idx):
        #iloc[ : , : ],冒号前面的取行数,后面的取列数,左闭右开原则.
        return self.df.iloc[idx].SalePrice #读取第idx行,SalePrice列的数据
ds_demo = BluebookDataset('F:\Desktop\median_benchmark.csv') #先下载对应的.CSV文件
print(len(ds_demo))
ds_demo[0]

结果:

11573
24000.0

2.读取数据集Dataloader

DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小)、 shuffle(是否进行shuffle操作)、 num_workers(加载数据的时候使用几个子进程)。

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
                            batch_sampler=None, num_workers=0, collate_fn=None, 
                            pin_memory=False, drop_last=False, timeout=0, 
                            worker_init_fn=None, multiprocessing_context=None)

参数解释:

  • dataset: Dataset 类,决定数据从哪里读取以及如何读取
  • batchsize: 批大小
  • num_works:num_works: 是否多进程读取数据
  • sheuffle: 每个 epoch 是否乱序
  • drop_last: 当样本数不能被 batchsize 整除时,是否舍弃最后一批数据
#返回一个可迭代的对象
dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)
for i, data in enumerate(dl):
    print(i,data)
    # 为了节约空间,这里只循环三遍
    if(i==2):
        break 

结果:

0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)
1 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)
2 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)

3.torchvision 包

3.1torchvision.datasets

torchvision.datasets 为PyTorch自定义的dataset,包括: - MNIST - COCO - Captions - Detection - LSUN - ImageFolder - Imagenet-12 - CIFAR - STL10 - SVHN - PhotoTour。

import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录
                          train=True,  # 表示是否加载数据库的训练集,false的时候加载测试集
                          download=True, # 表示是否自动下载 MNIST 数据集
                          transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理

加载时遇到的问题:

ImportError: The _imaging extension was built for another version of Pillow or PIL:
Core version: "9.2.0"
Pillow version: 9.2.0

解决方法:

重新安装一下pillow就可以了

pip uninstall Pillow
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pillow

下载成功

【pytorch学习笔记四】数据的加载(Dataset,Dataloader)和预处理(torchvision 详细介绍)_第2张图片

3.2torchvision.models

下载常用的模型,包括:- AlexNet - VGG - ResNet - SqueezeNet - DenseNet。

#我们直接可以使用训练好的模型,当然这个与datasets相同,都是需要从服务器下载的
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)

3.3torchvision.transforms

transforms 模块提供了一般的图像转换操作类,用作数据处理和数据增强。

from torchvision import transforms as transforms
transform = transforms.Compose([ #串联多个图片变换的操作,即想要执行的transform操作。
    transforms.RandomCrop(32, padding=4),  #先四周填充4层0,在把图像随机裁剪成32*32像素大小
    transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转
    transforms.RandomRotation((-45,45)), #随机旋转
    transforms.ToTensor(), #把图像转换为Tensor
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.229, 0.224, 0.225),inplace=False), #R,G,B每层的归一化用到的均值和方差,把图片3个通道中的数据整理到[-1, 1]区间,可以加快模型的收敛,mean和std可自己设定.
])

3.4 常见的torchvision.transforms图片操作

操作 功能
transforms.CenterCrop(size)#size为裁剪大小,超过原图大小自动补0 中心裁剪
transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode=‘constant’) #fill填充的像素大小,padding_mode:填充的模式(constant:填充fill设定的值, edge:填充边界的值, reflect or symmetric) 随机裁剪
transforms.RandomResizedCrop(size, scale=(0.01, 1.0), ratio=(0.75, 1.4), interpolation=2)#scale:面积随机在(0.01, 1.0)之间的比例缩放,ratio:长宽比随机在(0.75, 1.4)之间选取。 随机大小、随机宽高比裁剪图片
transforms.FiveCrop(size, vertical_flip=False)#size:最后裁剪的图片尺寸,vertical_flip:是否翻转。最后的 tensor 形状是 [5crops, c, h, w] 在图像的上下左右以及中心裁剪出尺寸为 size 的 5 张图片
transforms.RandomVerticalFlip(p=1) #p为翻转概率 水平或者垂直方向翻转图片
transforms.RandomRotation(degrees, resample=False, expand=False, center=None, fill=None) #degrees:旋转角度如(-45,60),resample是否重采样,expand:是否扩大矩形框,会改变图片的尺寸,center:旋转中心,默认是图片的中心。 随机旋转
transforms.Pad(padding=(16,12), fill=0, padding_mode=‘constant’)#padding:填充的大小,(16,12)表示上下填充16,左右填充12. 图片填充
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)#参数分别为亮度,对比度,饱和度和色相。 调整亮度、对比度、饱和度、色相。
transforms.Grayscale(p=0.1, num_output_channels=3) #p:转为灰度图的概率,num_output_channels:输出通道数。 转灰度图
transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0) #分别为旋转,平移,缩放,填充颜色,错切角及采样设置( NEAREST、BILINEAR、BICUBIC。)。 仿射变换
transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False) 图像随机遮挡
transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5) 根据概率执行或不执行一组 transforms 操作
transforms.RandomChoice([transforms1, transforms2, transforms3]) 随机选一个执行
transforms.RandomOrder([transforms1, transforms2, transforms3]) 打乱顺序执行一组 transforms 操作

参考资料:
https://handbook.pytorch.wiki/chapter2/2.1.4-pytorch-basics-data-loader.html

未完待续!

欢迎关注个人公众号【智能建造小硕】(分享计算机编程、人工智能、智能建造、日常学习和科研经验等,欢迎大家关注交流。)

你可能感兴趣的:(Pytorch学习笔记,pytorch,学习,深度学习)