【Pytorch】--- 数据的读取和操作(Dataset, DataLoader)

前言

Pytorch 中比较重要的是对数据的处理,其中,进行数据读取的一般有三个类:

  • Dataset
  • DataLoader

其中,这三是一个依次封装的关系:“Dataset被封装进DataLoaderDataLoader再被封装进DataLoaderIter

Dataset

Dataset位于torch.utils.data.Dataset,每当我们自定义类MyDataset必须要继承它并实现其两个成员函数:

  • __len__()
  • __getitem__()

例如:

import torch
from torch.utils.data import Dataset
import pandas as pd

# 定义自己的类
class MyDataset(Dataset):
    
    # 初始化
    def __init__(self, file_name):
        # 读入数据
        self.data = pd.read_csv(file_name)
    
    # 返回df的长度
    def __len__(self):
        return len(self.data)
    
    # 获取第idx+1列的数据
    def __getitem__(self, idx):
        return self.data[idx].label

# 通过实例化对象来访问该类
# 假设同目录下存在名为median_benchmark.csv的文件
ds = MyDataset('median_benchmark.csv')

'''
len(ds)     返回数据总数
ds[101]     返回索引处的数据
'''

DataLoader

DataLoader位于torch.utils.data.DataLoader, 为我们提供了对Dataset读取操作

# 仅仅列举了常用的几个参数
torch.nn.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
  • dataset : 上面所实现的自定义类Dataset
  • batch_size : 默认为1,每次读取的batch的大小
  • shuffle : 默认为False, 是否对数据进行shuffle操作(简单理解成将数据集打乱)
  • num_works : 默认为0,表示在加载数据的时候每次使用子进程的数量,即简单的多线程预读数据的方法

DataLoader返回的是一个迭代器,我们通过这个迭代器来获取数据

Dataloder的目的是将给定的 n n n个数据, 经过Dataloader操作后, 在每一次调用时调用一个小batch, 如:

  • 给出的是: ( 5000 , 28 , 28 ) (5000, 28, 28) (5000,28,28), 表示有 5000 5000 5000个样本,每个样本的size为 ( 28 , 28 ) (28, 28) (28,28)
  • 经过Dataloader处理后, 一次得到的是 ( 100 , 28 , 28 ) (100, 28, 28) (100,28,28)(假设batch_size大小为100), 表示本次取出100个样本, 每个样本的size为 ( 28 , 28 ) (28,28) (28,28)
# 连接上面的Dataset实现代码

from torch.utils.data import DataLoader

dl = DataLoader(ds, batch_size=10, shuffle=True, num_works=2)

通过迭代器来分次获取数据

dl_data = iter(dl)
print(next(dl_data))

'''
Output:(示例)

tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)

'''

或,直接通过for循环进行遍历输出

for i, data in enumerate(dl):
    print(i, data)
    
    # 这里只循环一次,所以用break
    break

'''
Output:(示例)

0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)
''''

参考资料

  • [1]. https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
  • [2]. https://github.com/pandadreamer/pytorch-handbook/blob/master/chapter2/2.1.4-pytorch-basics-data-lorder.ipynb
  • [3]. https://zhuanlan.zhihu.com/p/30934236

你可能感兴趣的:(Python)