PyTorch----数据加载器

什么是数据加载器?

  • 深度学习是由数据支撑起来的,所以我们一般在做深度学习的时候往往伴随着大量、复杂的数据。如果把所有的数据全部加载到内存上,容易把电脑的内存“撑爆”,所以要分批次一点点加载数据
  • 每一种深度学习的框架都有自己所规定的数据格式,数据加载器就有了必要的作用
  • 数据加载器就是把大量的数据,分批次加载和处理成框架所需要的数据格式

数据分批次加载

使用PyTorch内置的模块 torch.utils.data.DataLoader()数据加载器

参数

  1. dataset:数据集
  2. batch_size: 每一批数据的总量
  3. shuffle: True or False 为True的时候会将数据打乱再分批

PyTorch自带MNIST数据的分批

手写数字数据集

  1. 加载数据MNIST数据集在torchvision.datasets.MNIST中
import torch
import torchvision

train_dataset = torchvision.datasets.MNIST(root="./data1",train=True,transform=torchvision.transforms.ToTensor(),download=False)
  1. 取出一张图片展示
import numpy as np
import matplotlib.pyplot as plt

# 获取到第一条数据
data,label = train_dataset[0]

# 因为数据集里面的数据进行过归一化,所以要反归一化
img = np.array(data) * 255
img = img.reshape(28,28).astype(np.uint8)
# 展示
plt.imshow(img,'gray')
plt.show()

PyTorch----数据加载器_第1张图片

  1. 使用DataLoader方法分批次
from torch.utils.data import DataLoader
# 创建DataLoader对象
train_loader = DataLoader(dataset=train_dataset,batch_size=100,shuffle=True)

num_epochs = 1
for epoch in range(num_epochs):
	# 第二层循环会每次打开一批次的数据 当前一批次为100
    for i,(inputs,labels) in enumerate(train_loader):
        print(f'Epoch: {epoch+1}/{num_epochs},Step {i+1}/{len(train_dataset)/100}| Inputs {inputs.shape} | Labels {labels.shape}')
        # 当前inputs和labels里面有100条数据
        print(labels)
        break
print(len(train_loader))

PyTorch----数据加载器_第2张图片


自定义Dataset类

  • DataLoader()的dataset参数必须继承于PyTorch的Dataset类
  • 只有继承了PyTorch中的Dataset接口的类,才能够被传入DataLoader中

自定义一个Dataset类,让PyTorch去认识我们的数据

步骤:

  • 创建一个类继承Dataset

  • __init__魔法方法内读取数据

    • 获取到数据的长度
    • 获取到特征数据和输出标签
  • __getitem__方法内返回第index条数据

  • __len__方法内返回数据的长度

from torch.utils.data import Dataset
class WineDataset(Dataset):
    def __init__(self):
        # 读取csv数据
        xy = pd.read_csv("./wine.csv")
        # 获取到数据的长度
        self.n_samples = xy.shape[0]
        # 特征数据
        self.x_data = torch.from_numpy(xy.values[:,1:])
        # 输出标签
        self.y_data = torch.from_numpy(xy.values[:,0])
        
    def __getitem__(self,index):
        # 遍历的时候返回数据  可迭代对象
        return self.x_data[index],self.y_data[index]
    
    def __len__(self.n_samples):
        # 返回数据长度
        return self.n_sampleso

查看自定义的Dataset类:

# 使用DataLoader去加载数据集合
from torch.utils.data import DataLoader
import torch
wineData = WineDataset()
# 传入加载器
train_loader = DataLoader(dataset=wineData,batch_size=4,shuffle=True)
# 分批训练
# 迭代次数
epoch_num = 5
total_samples = len(wineData)
print("total_samples:",total_samples)
# 开始训练
for epoch in range(epoch_num):
    for i,(inputs,labels) in enumerate(train_loader):
        print(i,labels)

每次批次加载4条数据

PyTorch----数据加载器_第3张图片

因为数据分为4批次,是有余数的,所以最后一行数据不是4条:
在这里插入图片描述

你可能感兴趣的:(机器学习,pytorch,机器学习,深度学习)