Datawhale 零基础入门CV赛事学习笔记-Task2 数据读取与数据扩增

Datawhale 零基础入门CV赛事学习笔记-Task2 数据读取与数据扩增

1 数据读取

  在Pytorch中实现图像数据读取主要基于两个基类:Dataset和DataLoader,Dataset主要是通过索引加载图片并进行相应的处理,而DataLoader则进行图片的批量打包(batch)。

1.1 Dataset

  torch.utils.data.Dataset() :Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。__getitem__方法的是Dataset的核心,作用是接收一个索引, 返回一个样本, 参数里面接收index,然后我们需要编写究竟如何根据这个索引去读取我们的数据部分。
  在Dataset中还可以定义transforms,可以通过transforms进行图像的预处理。

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

data = SVHNDataset(train_path, train_label,
          transforms.Compose([
              # 缩放到固定尺寸
              transforms.Resize((64, 128)),

              # 随机颜色变换
              transforms.ColorJitter(0.2, 0.2, 0.2),

              # 加入随机旋转
              transforms.RandomRotation(5),

              # 将图片转换为pytorch 的tesntor
              transforms.ToTensor(),

              # 对图像像素进行归一化
              transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ]))

  SVHNDataset 继承了 Dataset 基类,并且重写了__getitem__和__len__方法。
  在__getitem__方法中,我们通过索引进行图片的读取,并且使用填充方法对不足6位的图像字符进行填充。
  在__init__方法中,我们需要传入单个图片的路径和图片中字符对应的标签,并且我们可以手动传入相对应的transforms方法。

1.2 DataLoader

  torch.utils.data.DataLoader() :构建可迭代的数据装载器,我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据。

  DataLoader的主要主要参数有5个:

  • dataset:Dataset类, 数据的读取和图片预处理
  • batch_size:每批图片数量,默认为1
  • num_workers:是否多进程读取机制,默认为0
  • shuffle:每个epoch是否乱序,默认为False
  • drop_last:当样本数不能被batch_size整除时, 是否舍弃最后一批数据。默认为False
import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数
)

for data in train_loader:
    break

  在加入DataLoder后,数据按照批次获取,每批次调用Dataset读取单个样本进行拼接。此时data的格式为:
torch.Size([10, 3, 64, 128]), torch.Size([10, 6])。
  前者为图像文件,为batchsize * chanel * height * width次序;后者为字符标签。

2 数据扩增

2.1 transforms

  在torchvision中,有三个主要的模块,分别是transforms、datasets和models。transforms 是torchvision计算机视觉工具包最常用的图像预处理方法。
  transforms包括实现图像裁剪、图像的旋转和图像变换,并通过transforms方法实现更多的图像操作,包括但不限于:数据中心化,数据标准化,缩放,裁剪,旋转,翻转,填充,噪声添加,灰度变换,线性变换,仿射变换,亮度、饱和度及对比度变换等等。

2.2 图像裁剪

  • transforms.RandomCrop: 对图像进行裁剪
  • transforms.CenterCrop(size):图像中心裁剪图片, size是所需裁剪的图片尺寸,如果比原始图像大了, 会默认填充0。
  • transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant):从图片中位置随机裁剪出尺寸为size的图片,size是尺寸大小,padding设置填充大小,pad_if_need: 若图像小于设定的size, 则填充。
  • padding_mode表示填充模型, 有4种:constant像素值由fill设定,edge像素值由图像边缘像素设定,reflect镜像填充,symmetric也是镜像填充。
  • transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3/4, 4/3), interpolation):随机大小,长宽比裁剪图片。 scale表示随机裁剪面积比例,ratio随机长宽比,interpolation表示插值方法。
  • FiveCrop, TenCrop:在图像的上下左右及中心裁剪出尺寸为size的5张图片,后者还在这5张图片的基础上再水平或者垂直镜像得到10张图片。

2.3 图像的翻转和旋转

  • RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5):依概率水平或者垂直翻转图片,p表示翻转概率
  • RandomRotation(degrees, resample=False, expand=False, center=None):随机旋转图片,degrees表示旋转角度,resample表示重采样方法,expand表示是否扩大图片,以保持原图信息。

2.4 图像变换

  • transforms.Compose: 将一系列的transforms方法进行有序的组合包装,具体实现的时候,依次的用包装的方法对图像进行操作。
  • transforms.Resize: 改变图像大小
  • transforms.ToTensor: 将图像转换成张量,同时会进行归一化的一个操作,将张量的值从0-255转到0-1
  • transforms.Normalize: 将数据进行标准化
  • transforms.Pad(padding, fill=0, padding_mode=‘constant’): 对图片边缘进行填充
  • transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):调整亮度、对比度、饱和度和色相。brightness是亮度+ 调节因子,contrast对比度参数,saturation饱和度参数,hue是色相因子。
  • transfor.RandomGrayscale(num_output_channels, p=0.1):依概率将图片转换为灰度图,第一个参数是通道数,只能1或3,p是概率值,转换为灰度图像的概率
  • transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):对图像进行仿射变换,反射变换是二维的线性变换 由五中基本原子变换构成,分别是旋转,平移,缩放,错切和翻转。 degrees表示旋转角度,+ translate表示平移区间设置,scale表示缩放比例,fill_color填充颜色设置,shear表示错切
  • transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): 对图像进行随机遮挡,p概率值,scale遮挡区域的面积,ratio遮挡区域长宽比,value遮挡像素。 随机遮挡有利于模型识别被遮挡的图片。这个是对张量进行操作,所以需要先转成张量才能做

你可能感兴趣的:(Datawhale 零基础入门CV赛事学习笔记-Task2 数据读取与数据扩增)