Pytorch学习——数据读取、数据预处理

一、DataLoader和DataSet

1.1 DataLoader

torch.utils.data.DataLoader   #构建可迭代的数据装载器
DataLoader()参数 :

  • dataset: Dataset类,决定数据从哪读取 及如何读取
  • batchsize : 批大小
  • num_works: 是否多进程读取数据
  • shuffle: 每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整 除时,是否舍弃最后一批数据
  • Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称之为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration 
  • 例如:样本总数:80, Batchsize:8      1 Epoch = 10 Iteration

1.2 DataSet

torch.utils.data.Dataset   

Dataset抽象类,所有自定义的 Dataset需要继承它,并且复写__getitem__()
getitem :接收一个索引,返回一个样本

例如:

from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):

    #初始化。一般写些该类的全局变量,为后边函数提供变量
    def __init__(self, data_dir, label_dir):
        #不同的数据格式,此处的处理方法不同
        self.data_dir = data_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.data_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  #获取路径中的每一张图片
        
    def __getitem__(self, item):
        img_name = self.img_path[item]
        img_item_path = os.path.join(self.data_dir, self.label_dir, img_name)   #获取图片路径
        img = Image.open(img_item_path)
        label = self.label_dir
        
        return img, label
    
    def __len__(self):
        return len(self.img_path)

二、transforms

2.1 transforms

torchvision:计算机视觉工具包

torchvision.transforms : 常用的图像预处理方法——数据中心化 • 数据标准化 • 缩放 • 裁剪 • 旋转 • 翻转 • 填充 • 噪声添加 • 灰度变换 • 线性变换 • 仿射变换 • 亮度、饱和度及对比度变换

torchvision.datasets : 常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet等

torchvision.model : 常用的模型预训练,AlexNet,VGG, ResNet,GoogLeNet等

transforms.Normalize(mean, std, inplace=False)  #逐channel的对图像进行标准化为(-1,1)

  • output = (input - mean) / std
  • mean:各通道的均值
  • std:各通道的标准差
  • inplace:是否原地操作

你可能感兴趣的:(深度学习,深度学习)