Pytorch学习笔记-02 数据读取与处理

02 数据读取与处理

文章目录

  • 02 数据读取与处理
    • DataLoader 与 Dataset
      • torch.utils.data.DataLoader
      • torch.utils.data.Dataset
      • 数据读取流程
    • transforms
    • transforms图像增强
      • 数据增强
      • transforms——Crop
      • transforms——Flip and Rotation
      • 图像变换
      • 自定义transforms
      • 总结:
      • 数据增强实战
      • 总结:
      • 数据增强实战

DataLoader 与 Dataset

深度学习模型训练一般流程

Pytorch学习笔记-02 数据读取与处理_第1张图片

torch.utils.data.DataLoader

功能:构建可迭代的数据装载器

  • dataset: Dataset类,决定数据从哪读取及如何读取
  • batchsize : 批大小
  • num_works: 是否多进程读取数据
  • shuffle: 每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

概念辨析:

Epoch: 所有训练样本都已输入到模型中,称为一个Epoch

Iteration:一批样本输入到模型中,称之为一个Iteration

Batchsize:批大小,决定一个Epoch有多少个Iteration

e.g.:


class MyDataset(Dataset):
    def __init__(self, data_dir, transforms=None):
        super().__init__()
        self.Label = {
     '1':0, '100',1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)

torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()

getitem :接收一个索引,返回一个样本

数据读取流程

Pytorch学习笔记-02 数据读取与处理_第2张图片

transforms

torchvision.transforms : 常用的图像预处理方法

  • 数据中心化
  • 数据标准化
  • 缩放
  • 裁剪
  • 旋转
  • 翻转
  • 填充
  • 噪声添加
  • 灰度变换
  • 线性变换
  • 仿射变换
  • 亮度、饱和度及对比度变换

作用位置:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WEO1DU2s-1637072332139)(https://i.loli.net/2021/11/16/9c4xH5n6JegIPjS.png)]

transforms图像增强

数据增强

数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2IQ1OkSu-1637072332143)(https://i.loli.net/2021/11/16/OxMqR7fTFnL8VQs.png)]

transforms——Crop

  1. transforms.CenterCrop

    功能:从图像中心裁剪图片

  2. transforms.RandomCrop
    功能:从图片中随机裁剪出尺寸为size的图片

  3. RandomResizedCrop
    功能:随机大小、长宽比裁剪图片

  4. FiveCrop

  5. TenCrop

    功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片

transforms——Flip and Rotation

  • RandomHorizontalFlip

  • RandomVerticalFlip

    • 功能:依概率水平(左右)或垂直(上下)翻转图片
  • RandomRotation

    功能:随机旋转图片
    **degrees:**旋转角度
    当为a时,在(-a,a)之间选择旋转角度
    当为(a, b)时,在(a, b)之间选择旋转角度
    **resample:**重采样方法
    **expand:**是否扩大图片,以保持原图信息

图像变换

  1. Pad
    功能:对图片边缘进行填充
transforms.Pad(padding,
fill=0,
padding_mode='constant')
  1. ColorJitter
    功能:调整亮度、对比度、饱和度和色相
transforms.ColorJitter(brightness=0,
contrast=0,
saturation=0,
hue=0)
  1. Grayscale

  2. RandomGrayscale

    功能:依概率将图片转换为灰度图

    RandomGrayscale(num_output_channels,
    p=0.1)
    Grayscale(num_output_channels)
    
  3. RandomAffine
    功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转

    RandomAffine(degrees,
    translate=None,
    scale=None,
    shear=None,
    resample=False,
    fillcolor=0)
    
  4. RandomErasing
    功能:对图像进行随机遮挡

  5. transforms.Lambda

功能:用户自定义lambda方法

  1. transforms.RandomChoice
transforms.RandomChoice([transforms1, transforms2, transforms3])

功能:从一系列transforms方法中随机挑选一个

  1. transforms.RandomApply
transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

功能:依据概率执行一组transforms操作

  1. transforms.RandomOrder
transforms.RandomOrder([transforms1, transforms2, transforms3])

功能:对一组transforms操作打乱顺序

自定义transforms

自定义transforms要素:

  • 仅接收一个参数,返回一个参数
  • 注意上下游的输出与输入
class YourTransforms(object):
	def __init__(self, ...):
	
	def __call__(self, img):

	return img

总结:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-z9YuW2g3-1637072332145)(https://i.loli.net/2021/11/16/nC9JMbkfG1dZQeT.png)]

数据增强实战

原则:让训练集与测试集更接近
ansforms要素:

  • 仅接收一个参数,返回一个参数
  • 注意上下游的输出与输入
class YourTransforms(object):
	def __init__(self, ...):
	
	def __call__(self, img):

	return img

总结:

[外链图片转存中…(img-z9YuW2g3-1637072332145)]

数据增强实战

原则:让训练集与测试集更接近

你可能感兴趣的:(Pytorch学习,机器学习,python,神经网络)