PyTorch加载自己的数据集

PyTorch加载自己的数据集

    • 0.引言
    • 1.定义自己的Dataset类

0.引言

  • 主要内容来源。

PyTorch提供了几种方法来加载自己的数据集。下面是一些常用的方法:

  • 1.使用torch.utils.data.Dataset类创建自定义数据集

这是一种常见的方式,用于自定义数据集。创建一个类,继承自torch.utils.data.Dataset,并重写__len__()__getitem__()方法。__len__()方法应该返回数据集的大小,__getitem__()方法应该返回一个样本。例如,以下是一个自定义数据集类的示例:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return x, y
  • 2.使用torch.utils.data.DataLoader类加载数据集

torch.utils.data.DataLoader类用于加载数据集。它可以自动对数据集进行批处理、打乱和多线程加载。下面是一个使用DataLoader加载数据集的示例:

from torch.utils.data import DataLoader

dataset = MyDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 3.使用torchvision.datasets模块加载常见数据集
    torchvision.datasets模块提供了一些常见的数据集,例如MNIST、CIFAR等。可以使用这些数据集来测试模型或学习如何加载数据集。以下是一个使用torchvision.datasets加载MNIST数据集的示例:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

trainset = datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)

上面的代码将下载MNIST数据集,并使用ToTensor()和Normalize()转换图像。然后使用DataLoader加载数据集。

1.定义自己的Dataset类

创建一个类,继承自torch.utils.data.Dataset,并重写__len__()__getitem__()方法:

  • __init__ 用于向类中传入外部参数,同时定义样本集
  • __len__()方法应该返回数据集的大小
  • __getitem__()方法应该返回一个样本

例如,以下是一个自定义数据集类的示例:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return x, y

这里另外给出一个例子,其中图片存放在一个文件夹,另外有一个csv文件给出了图片名称对应的标签。这种情况下需要自己来定义Dataset类:

class MyDataset(Dataset):
    def __init__(self, data_dir, info_csv, image_list, transform=None):
        """
        Args:
            data_dir: path to image directory.
            info_csv: path to the csv file containing image indexes
                with corresponding labels.
            image_list: path to the txt file contains image names to training/validation set
            transform: optional transform to be applied on a sample.
        """
        label_info = pd.read_csv(info_csv)
        image_file = open(image_list).readlines()
        self.data_dir = data_dir
        self.image_file = image_file
        self.label_info = label_info
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item
        Returns:
            image and its labels
        """
        image_name = self.image_file[index].strip('\n')
        raw_label = self.label_info.loc[self.label_info['Image_index'] == image_name]
        label = raw_label.iloc[:,0]
        image_name = os.path.join(self.data_dir, image_name)
        image = Image.open(image_name).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_file)

构建好Dataset后,就可以使用DataLoader来按批次读入数据了,实现代码如下:

from torch.utils.data import DataLoader

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)

其中:

  • batch_size:样本是按“批”读入的,batch_size就是每次读入的样本数

  • num_workers:有多少个进程用于读取数据,Windows下该参数设置为0,Linux下常见的为4或者8,根据自己的电脑配置来设置

  • shuffle:是否将读入的数据打乱,一般在训练集中设置为True,验证集中设置为False

  • drop_last:对于样本最后一部分没有达到批次数的样本,使其不再参与训练

这里可以看一下加载的数据。PyTorch中的DataLoader的读取可以使用next和iter来完成

import matplotlib.pyplot as plt
images, labels = next(iter(val_loader))
print(images.shape)
plt.imshow(images[0].transpose(1,2,0))
plt.show()

你可能感兴趣的:(深度学习,pytorch,深度学习,python,PyTorch加载数据集)