pytorch入门基础知识小白必备,强推,超详细!!!(四)pytorch如何加载数据和进行预处理

这一篇博客是关于如何在pytoch里加载训练数据到网络中的,同志们来一起学习吧~
声明一下呀,源代码是GitHub的大神的,我只是搬运工和修理工
原文链接:https://github.com/yehaizi1995/pytorch-handbook/blob/master/chapter2/2.1.4-pytorch-basics-data-loader.ipynb

文章目录

  • 数据加载和预处理
    • 数据加载
  • 预处理
    • torchvision.models
    • torchvision.transforms

数据加载和预处理

数据加载

PyTorch通过torch.utils.data对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。
并且torchvision已经预先实现了常用图像数据集,包括前面使用过的CIFAR-10,ImageNet、COCO、MNIST、LSUN等数据集
可通过torchvision.datasets方便的调用
pandas是python数据分析库 pandas中主要有两种数据结构:Series 和 DataFrame

Dataset是一个抽象类, 为了能够方便的读取,需要将要使用的数据包装为Dataset类。 自定义的Dataset需要继承它并且实现两个成员方法:
1、getitem() 该方法定义用索引(0 到 len(self))获取一条数据或一个样本
2、len() 该方法返回数据集的总长度

如下面例子所示:

#!/usr/bin/env python
# -*- coding:utf-8 -*- 
# Author: yehaizi time:2019/8/14:15:41


import torch
from torch.utils.data import Dataset
import pandas as pd
# 定义一个数据集
class BulldozerDataset(Dataset):
    def __init__(self, csv_file):
        self.df=pd.read_csv(csv_file)      # 实现初始化方法,在初始化的时候将数据读载入

# 返回df的长度
    def __len__(self):
        return len(self.df)

# 根据 idx 返回一行数据
    def __getitem__(self, idx):
        return self.df.iloc[idx].SalePrice
# median_benchmark.csv这个是一个数据文件,嗯,我没找到原作者提供的文件,所以就暂时不用理他,我们看这份代码的主要思想和具体步骤就好
ds_demo= BulldozerDataset('median_benchmark.csv')

# 我们可以直接使用如下命令查看数据集数据
# 实现了 __len__ 方法所以可以直接使用len获取数据总数
print(len(ds_demo))

# 用索引可以直接访问对应的数据, 对应 __getitem__ 方法
print(ds_demo[0])

# 下面我们使用官方提供的数据载入器,读取数据
# DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小), shuffle:是否进行shuffle操作,
# shuffle() 方法将序列的所有元素随机排序
# num_workers(加载数据的时候使用几个子进程),下面做一个简单的操作


dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)

#  DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据
idata=iter(dl)
print(next(idata))

#  常见的用法是使用for循环对其进行遍历

for i, data in enumerate(dl):
    print(i,data)
    # 为了节约空间, 这里只循环一遍
    break

预处理

torchvision包是pytorch里面的一个图像处理包,torchvision.datasets里面具有许多数据集,比如手写数字MINIST,Imagenet,并且已经处理好了,可以拿来直接用。比如下面:

import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录
                                      train=True,  # 表示是否加载数据库的训练集,false的时候加载测试集
                                      download=True, # 表示是否自动下载 MNIST 数据集
                                      transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理

torchvision.models

torchvision不仅提供了常用图片数据集,还提供了训练好的模型,可以加载之后,直接使用。比如vgg,alexnet,resnet等

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)

torchvision.transforms

transforms 模块提供了一般的图像转换操作类,用作数据处理和数据增强。

from torchvision import transforms as transforms
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  #先四周填充0,在把图像随机裁剪成32*32
    transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转
    transforms.RandomRotation((-45,45)), #随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差
])

终于写完了~ 觉得有用请点赞关注一波吧,O(∩_∩)O谢谢

你可能感兴趣的:(pytorch学习系列文章)