PyTorch使用Dataset和DataLoader加载数据集

在PyTorch里优化器都是小批量小批量地优化训练的,即每次都会从原数据集中取出一小批量进行训练,完成一次权重更新后,再从原数据集中取下一个小批量数据,然后再训练再更新。。。
比如最常用的小批量随机梯度下降(Mini-Batch Gradient Descent,MBGD)。

毕竟原数据集往往很大,不可能一次性的全部载入内存,只能一小批一小批地载入内存。训练完了就扔了,再加载下一小批。


如何实现批量地加载数据集?

在PyTorch的torch.utils.data包中定义了两个类Dataset和DataLoader,这两个类就是用来“批量地加载数据”。

下面说一下其用法,也就是如何批量地加载数据:

写一个简单的数据加载器

import numpy as np
import torch 
# utils是工具包
from torch.utils.data import Dataset  # Dataset是个抽象类,只能用于继承
from torch.utils.data import DataLoader # DataLoader需实例化,用于加载数据


class GetData(Dataset):   # 继承Dataset类
    def __init__(self, csv_path): 
        df = pd.read_csv(csv_path, sep='\t') # 加载csv数据集文件
        
        # 把数据和标签拿出来        
        self.data = df['data'] 
        self.label = df['label']

        # 数据集的长度
        self.length = data.shape[0]
        
    # 下面两个魔术方法比较好写,直接照着这个格式写就行了 
    def __getitem__(self, index): # 参数index必写
        return self.x_data[index], self.y_data[index]
    
    def __len__(self): 
        return self.length # 只需返回一个数据集长度即可

# 实例化    
dataset = GetData('data/diabetes.csv') 
train_loader = DataLoader(dataset=dataset, # 要传递的数据集
                          batch_size=32, #一个小批量数据的大小是多少
                          shuffle=True, # 数据集顺序是否要打乱,一般是要的。测试数据集一般没必要
                          num_workers=2) # 需要几个进程来一次性读取这个小批量数据

使用方法

类似于迭代器的使用

for epoch in range(100): 
	
	# 主要看下面两行代码
    for i, data in enumerate(train_loader, 0): 
        # 1. 数据准备 
        inputs, labels = data 
        
        # 2. 前向传播 
        y_pred = model(inputs) 
        loss = criterion(y_pred, labels) 
        print(epoch, i, loss.item()) 
        # 3. 反向传播
        optimizer.zero_grad() 
        loss.backward() 
        # 4. 权重/模型更新 
        optimizer.step()

总结

模板如下:

class GetData(Dataset): 
    def __init__(self): 
        '''
        有两种写法:
        1、将全部数据都加载进内存里,适用于少量数据(上面那个例子就是全部加载);
        2、当数据量或者标签量很大时,比如图片,就把这些数据或者标签放到文件或数据库里去,只需在此方法中初始化定义这些文件索引的列表即可。
        '''
        pass
    
    # 以下2个方法都是魔法方法
    def __getitem__(self, index): # 表示将来实例化这个对象后,它能支持下标(索引)操作,也就是能通过索引把里面的数据拿出来。
        pass
    def __len__(self):  # 返回数据集条数
        pass
    
dataset = DiabetesDataset() 
train_loader = DataLoader(dataset=dataset, # 传递数据集
                          batch_size=32, #一个小批量容量是多少
                          shuffle=True, # 数据集顺序是否要打乱,一般是要的。测试数据集一般没必要
                          num_workers=2) # 需要几个进程来一次性读取这个小批量数据

你可能感兴趣的:(PyTorch,pytorch)