pytorch学习笔记-基础-数据的加载和预处理

PyTorch 基础 :数据的加载和预处理

PyTorch通过torch.utils.data对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。 并且torchvision已经预先实现了常用图像数据集,包括前面使用过的CIFAR-10,ImageNet、COCO、MNIST、LSUN等数据集,可通过torchvision.datasets方便的调用

import torch
torch.__version__
'1.2.0'

Dataset

Dataset是一个抽象类, 为了能够方便的读取,需要将要使用的数据包装为Dataset类。 自定义的Dataset需要继承它并且实现两个成员方法:

__getitem__() 该方法定义用索引(0 到 len(self))获取一条数据或一个样本
__len__() 该方法返回数据集的总长度

下面我们使用kaggle上的一个竞赛bluebook for bulldozers自定义一个数据集,为了方便介绍,我们使用里面的数据字典来做说明(因为条数少)

from torch.utils.data import Dataset
import pandas as pd 
class BulldozerDateset(Dataset): #继承Dataset类
    """数据集演示"""
    def __init__(self,csv_file):
        self.df = pd.read_csv(csv_file)
    def __len__(self):
        #返回df的长度
        return len(self.df)
    def __getitem__(self,idx):
        # 根据idx返回一行数据
        return self.df.iloc[idx]['x1']  #读取下标为idx的一行数据。可以在后面加 .列名 的方式读取属于某一列的数据,否则就会读取这一整行
data = BulldozerDateset(r'data/datatest_1.csv')
# 实现了 __len__ 方法所以可以直接使用len()获取数据总数
len(data)
25
#用索引可以直接访问对应数据,对应 __getitem__方法:
for i in range(len(data)):
    print(data[i])
#     x1 , x2 = data[i]
#     print("x1 = ",x1,"x2 = ",x2)
0.232991543
0.449915356
0.840922298
0.20727367
0.541869015
0.36092917399999996
0.668949803
0.15037799400000001
0.898436358
0.302533521
0.286306281
0.8299904779999999
0.69587861
0.848728081
0.527168802
0.231401747
0.5379995989999999
0.6874250009999999
0.5887086970000001
0.252984848
0.067723107
0.220923151
0.49844191299999996
0.886522632
0.5958730729999999

Dataloader

DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程),下面做一个简单的操作

dl = torch.utils.data.DataLoader(data, batch_size= 5 ,shuffle= True , num_workers= 0)
print(dl)

DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据

idata = iter(dl)
print(next(idata))
tensor([0.2073, 0.6689, 0.2330, 0.5272, 0.3025])

对Dataloader常见的用法是使用for循环对其进行遍历

for i , value in enumerate(dl):
    print(i, ": ",value)
    
0 :  tensor([0.8984, 0.6689, 0.0677, 0.5380, 0.2530])
1 :  tensor([0.8300, 0.5959, 0.4984, 0.4499, 0.3025])
2 :  tensor([0.5887, 0.2863, 0.2330, 0.8409, 0.8865])
3 :  tensor([0.6874, 0.1504, 0.2209, 0.8487, 0.6959])
4 :  tensor([0.2073, 0.5272, 0.5419, 0.3609, 0.2314])

torchvision包

orchvision 是PyTorch中专门用来处理图像的库,PyTorch官网的安装教程中最后的pip install torchvision 就是安装这个包。
torchvision.datasets

torchvision.datasets 可以理解为PyTorch团队自定义的dataset,这些dataset帮我们提前处理好了很多的图片数据集,我们拿来就可以直接使用:

MNIST
COCO
Captions
Detection
LSUN
ImageFolder
Imagenet-12
CIFAR
STL10
SVHN
PhotoTour 我们可以直接使用,示例如下:
import torchvision.datasets as datasets
train_set = datasets.MNIST(
    root ='./data',  #表示加载数据的目录
    train = True ,   #表示是否加载数据集的训练集,false的话就是测试集
    download = True, #表示是否自动下载数据集
    transform = None , #表示是否对数据集预处理,如归一化。none为不进行预处理
    )
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


 99%|█████████████████████████████████████████████████████████████████████████████▌| 9.85M/9.91M [00:20<00:00, 707kB/s]

Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz



0.00B [00:00, ?B/s]
  0%|                                                                                      | 0.00/28.9k [00:00

torchvision.models

torchvision不仅提供了常用图片数据集,还提供了训练好的模型,可以加载之后,直接使用,或者在进行迁移学习 torchvision.models模块的 子模块中包含以下模型结构。

AlexNet
VGG
ResNet
SqueezeNet
DenseNet
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to C:\Users\Lowry/.cache\torch\checkpoints\resnet18-5c106cde.pth

  0%|                                                                                      | 0.00/44.7M [00:00

torchvision.transforms

transforms 模块提供了一般的图像转换操作类,用作数据处理和数据增强

from torchvision import transforms as transforms
transform = transforms.Compose([
    transforms.RandomCrop(32,padding=4), #先在四周填充0,把图像随机材成32 * 32
    transforms.RandomHorizontalFlip() , #图像一半概率翻转,一般概率不翻转
    transforms.RandomRotation((-45,45)) ,#随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)) ##R,G,B每层的归一化用到的均值和方差    
])

你可能感兴趣的:(Pytorch)