【PyTorch】(二)加载数据集

文章目录

  • 通用方法

通用方法

  • 创建数据集
    主要是将数据集读入内存,并用Dataset类封装。直接继承Dataset类的自定义数据集必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。
  • 加载数据集
    使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下:
    • dataset
      要加载的数据集。
    • batch_size
      每个数据批次中包含的样本数。默认为1。
    • shuffle
      是否打乱数据集。默认为False。
    • num_workers
      使用几个进程来加载数据。默认为0,即在主进程中加载数据。
    • drop_last
      当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。
  • 将数据转移到GPU
    1. 可以使用方法:变量.to(device)
    2. 可以使用方法:变量.cuda(0)
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class BostonHousingDataset(Dataset):
	"""定义波士顿房价数据集"""
    def __init__(self):
        self.data = np.load('../dataset/boston_housing/boston_housing.npz')

    def __getitem__(self, index):
        return self.data['x'][index], self.data['y'][index]

    def __len__(self):
        return self.data['x'].shape[0]

dataset = BostonHousingDataset()
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for X,y in dataloader:
    # 将数据转移到GPU
    X = X.to(device)
    y = y.to(device)
    # 也可以
    X = X.cuda()
    y = y.cuda()

你可能感兴趣的:(深度学习,pytorch)