数据读取机制Dataloader和Dataset和Transforms

人民币二分类模型

数据-模型-损失函数-优化器-迭代训练

  • 数据收集 img label
  • 数据划分 train valid test
  • 数据读取 Dataloader [sampler-生成索引 dataset-img,label]
  • 数据预处理 transforms

DataLoader

import torch

torch.utils.data.DataLoader()

  • 功能:构建可以可迭代的数据装载器
  • 参数:

    dataset Dataset类,决定数据从哪里读取和如何读取

    batchsize 批大小

    num_works 是否多进程读取数据

    shuffle每个epoch是否乱序

    drop_last 当样本数不能被batchsize整除时,是否舍弃最后一批数据

torch.utils.data.Dataloader(

dataset,

batch_size=1,

shuffle=False,

sampler=None,

batch_sampler=None,

num_workers=0,

drop_last=False

)


Epoch:训练样本都输入到模型中,称为一个Epoch

Iteration:一批样本输入到模型中,称之为一个Iteration

Batchsize:批大小,决定一个Epoch有多少个Iteration

例子:

样本总80,batchsize 8 ,1Epoch = 10 Iteration

1 Epoch = 10 Iteration ? drop_last = True

1 Epoch = 11 Iteration ? drop_last = False

# 功能Dataset抽象类,所有自定义的Dataset需要基础它,并且复写
# __getitem__()
# getitem:接收一个索引,返回一个样本

class Dataset(object):
    
    def __getitem__(self, index):
        raise NotImplementedError
    
    def __add__(self, other):
        return ConcatDataset([self,other])

import os
dir_test = os.path.join('..','..','data')
print(dir_test)
..\..\data

Transforms

常见的处理方法有:

  • 数据中心化
  • 数据标准化
  • 缩放
  • 剪裁
  • 旋转
  • 翻转
  • 填充
  • 噪声添加
  • 灰度变换
  • 线性变换
  • 仿射变换
  • 亮度、饱和度和对比度变换

transforms.Normalize(mean,std,inplace=False)

数据标准化,能加速模型收敛

数据增强方法

import os
import numpy as np
import torch
import random
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot  as plt

1.tranforms–CenterCrop

transforms.CenterCrop()中心裁剪

参数 size:所需要的尺寸

2.tranforms–RandomCrop

transforms.RandomCrop()随机裁剪

参数

size:所需要的尺寸

padding:填充大小

pad_if_need:如图像小于设定的size,则填充

padding_mode:填充模式,constant像素值fill设定,edge像素值由边缘图像决定,reflect镜像填充,最有个像素不镜像,symmertric镜像填充,最有一个像素镜像

fill:constant设置填充的像素值,如图像小于设定的size,则填充

3.tranforms–RandomResizedCrop

transforms.RandomResizedCrop()中心裁剪

参数 size:所需要的尺寸

参数 scale:随机裁剪面积比例(0.08,1)

参数 ratio:随机长宽比(3/4,4/3)

参数 interpolation:插值方法
* PIL.Image.NEAREST
* PIL.Image.BILINEAR
* PIL.Image.BICUBIC

4.tranforms–FiveCrop

transforms.FiveCrop()在图像的上下左右以及中心裁剪出尺寸为size的10张图片

参数 size:所需要的尺寸

5.tranforms–TenCrop

transforms.TenCrop()在图像的上下左右以及中心裁剪出尺寸为size的10张图片

参数 size:所需要的尺寸

参数 vertical_flip:是否翻转

6.tranforms–RandomHorizontalFlip

transforms.RandomHorizontalFlip()依概率水平翻转【左右】

参数 p:翻转概率

7.tranforms–RandomVerticalFlip

transforms.RandomVerticalFlip()依概率水平垂直【上下】

参数 p:翻转概率

8.tranforms–RandomRotation

transforms.RandomRotation()依概率旋转

参数 degresss:旋转角度

参数 resample:重采样

参数 expand:扩大图片,保持原图信息

参数 center:旋转中心

9.tranforms–Pad

transforms.Pad()对图片边缘填充

参数 padding:设置填充大小

参数 padding_mode:填充模式,分别是constant、edge、reflect、symmetric

参数 fill:为constant时填充像素值

10.tranforms–ColorJitter

transforms.ColorJitter()调节亮度、对比度、饱和度和色相

参数 brightness:调节亮度因子

参数 constrast:调节对比度参数

参数 saturation:调节饱和度

参数 hue:调节色相参数

11.tranforms–Grayscale

transforms.Grayscale()依概率图片转换为灰度

参数 num_output_channels:输出通道数智能设1或3

参数 p:转化为灰度的概率

11.tranforms–RandomGrayscale

transforms.RandomGrayscale()依概率图片转换为灰度

参数 num_output_channels:输出通道数智能设1或3

参数 p:转化为灰度的概率

12.tranforms–RandomAffine

transforms.RandomAffine()对图像进行仿射变换,仿射变换是二维的线性变换,有五种基本原子变换构成,分别是旋转、平移、缩放、错切、翻转

参数 degrees:旋转角度设置

参数 translate:平移区间设置 a设置宽width,b设置高height

参数 scale:缩放比例

参数 fill_color:填充颜色设置

参数 shear:错切角度设置,有水平错切和垂直错切。(a=X轴角度,b=Y轴角度)

参数 resample:重采样

13.tranforms–RandomErasing

transforms.RandomErasing()对图像进行随机遮挡

参数 p:执行遮挡的概率

参数 scale:遮挡区域面积

参数 p:遮挡区域长宽比

参数 p:设置遮挡区域的像素值

14.tranforms–Lambda

transforms.Lambda()用户自定义Lambda方法

表达式:

lambda[arg1[,arg2,…,argn]]:expression

transforms.Tencrop(200,vertical_filp=True)
transforms.Lambda(lambda crops:torch.stack([transforms.Totensor()(crop) for crop in crops])) 

15.tranforms–RandomChoice

transforms.RandomChoice()随机选择一个transforms方法

transforms.RandomChoice([方法1,方法2,方法3])

16.tranforms–RandomApply

transforms.RandomApply()依概率执行一组transforms方法

transforms.RandomChoice([方法1,方法2,方法3],p=0.5)

17.tranforms–RandomOrder

transforms.RandomOrder()对一组transforms操作打乱顺序

transforms.RandomChoice([方法1,方法2,方法3])

18 自定义transforms方法

  • 1.仅接收一个参数,返回一个参数
  • 2.注意上下游的输出与输入
class Compose(object):
    def __call__(self,img):
        for t in self.transforms:
            img = t(img)
        return img

通过类实现多参数传入

class YourTransforms(object):
    def __init__(self,...):
        ...
    def __call__(self,img):
        ...
        return img
  • 椒盐噪声:又叫做脉冲噪声,是一种随机出现的白点或者黑点,白点叫盐噪声,黑点叫椒噪声
  • 信噪比(Signal-Noise Rate,SNR)是衡量噪声的比例,图像中为图像像素的占比

一个例子

Class AddPepperNoise(object):
    """
    Args:
        snr(float):signal noise rate
        p(float):概率值,依概率执行操作
    """
    def __init__(self,snr,p=0.9):
        assert isinstance(snr,float) or (isinstance(p,float))
        self.snr = snr
        self.p = p
    def __call__(self,img):
        
        if random.uniform(0,1) < self.p:
            img_ = np.array(img).copy()
            h ,w ,c = img_.shape
            signal_pct = self.snr
            noise_pct = (1 - self.snr)
            mask = np.random.choice((0,1,2),size = (h,w,1),p=[signal_pct,noise_pct/2.,noise_pct/2])
            mask = np.repeat(mask,c,axis=2)
            img_[mask == 1] = 255 # 盐噪声
            img_[mask == 1] = 0 # 椒噪声
            return Image.fromarray(img_.astype('utf-8')).convert('RGB')
        else:
            return img

你可能感兴趣的:(Pytorch,pytorch,深度学习,python)