Pytorch 中比较重要的是对数据的处理,其中,进行数据读取的一般有三个类:
其中,这三是一个依次封装的关系:“Dataset
被封装进DataLoader
,DataLoader
再被封装进DataLoaderIter
”
Dataset
位于torch.utils.data.Dataset
,每当我们自定义类MyDataset
必须要继承它并实现其两个成员函数:
__len__()
__getitem__()
例如:
import torch
from torch.utils.data import Dataset
import pandas as pd
# 定义自己的类
class MyDataset(Dataset):
# 初始化
def __init__(self, file_name):
# 读入数据
self.data = pd.read_csv(file_name)
# 返回df的长度
def __len__(self):
return len(self.data)
# 获取第idx+1列的数据
def __getitem__(self, idx):
return self.data[idx].label
# 通过实例化对象来访问该类
# 假设同目录下存在名为median_benchmark.csv的文件
ds = MyDataset('median_benchmark.csv')
'''
len(ds) 返回数据总数
ds[101] 返回索引处的数据
'''
DataLoader
位于torch.utils.data.DataLoader
, 为我们提供了对Dataset
的读取操作
# 仅仅列举了常用的几个参数
torch.nn.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
dataset
: 上面所实现的自定义类Dataset
batch_size
: 默认为1,每次读取的batch的大小shuffle
: 默认为False, 是否对数据进行shuffle操作(简单理解成将数据集打乱)num_works
: 默认为0,表示在加载数据的时候每次使用子进程的数量,即简单的多线程预读数据的方法DataLoader
返回的是一个迭代器,我们通过这个迭代器来获取数据
Dataloder
的目的是将给定的 n n n个数据, 经过Dataloader
操作后, 在每一次调用时调用一个小batch, 如:
Dataloader
处理后, 一次得到的是 ( 100 , 28 , 28 ) (100, 28, 28) (100,28,28)(假设batch_size大小为100), 表示本次取出100个样本, 每个样本的size为 ( 28 , 28 ) (28,28) (28,28)# 连接上面的Dataset实现代码
from torch.utils.data import DataLoader
dl = DataLoader(ds, batch_size=10, shuffle=True, num_works=2)
通过迭代器来分次获取数据:
dl_data = iter(dl)
print(next(dl_data))
'''
Output:(示例)
tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
24000.], dtype=torch.float64)
'''
或,直接通过for循环进行遍历输出
for i, data in enumerate(dl):
print(i, data)
# 这里只循环一次,所以用break
break
'''
Output:(示例)
0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
24000.], dtype=torch.float64)
''''
参考资料