PyTorch:数据读取1 - Datasets

-柚子皮-

什么是Datasets?

在输入流水线中,准备数据的代码是这么写的

data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)

datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

为什么要定义Datasets?

PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们可以让数据变成mini-batch,且在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。

Datasets就是构建这个类的实例的参数之一。

DataLoader的使用参考[]。

-柚子皮-

 

自定义Datasets

框架

dataset必须继承自torch.utils.data.Dataset。内部要实现两个函数:一个是__lent__用来获取整个数据集的大小,一个是__getitem__用来从数据集中得到一个数据片段item

import torch.utils.data as data
class CustomDataset(data.Dataset):  # 继承data.Dataset
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, filename, data_info, oth_params):
        """Reads source and target sequences from txt files."""
        # # # Initialize file path or list of file names.
        self.file = open(filename, 'r')
        pass
        # # # 或者从外部数据结构data_info中读取数据
        self.all_texts = data_info['all_texts']
        self.all_labels = data_info['all_labels']
        self.vocab = data_info['vocab']

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        # # # 从文件读取
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform或者word2id什么的).
        # 3. Return a data pair(source and target) (e.g. image and label).
        pass
        # # # 或者直接读取
        item_info = {
            "text": self.all_texts[index],
            "label": self.all_labels[index]
        }
        return item_info

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        # return 0
        return len(self.all_texts)

小示例

class Dataset(torch.utils.data.Dataset):
    def __init__(self, filepath=None,dataLen=None):
        self.file = filepath
        self.dataLen = dataLen
        
    def __getitem__(self, index):
        A,B,path,hop= linecache.getline(self.file, index+1).split('\t')
        return A,B,path.split(' '),int(hop)

    def __len__(self):
        return self.dataLen

官方MNIST的例子

(代码被缩减,只留下了重要的部分):

class MNIST(data.Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return 60000
        else:
            return 10000

from: -柚子皮-

ref: [pytorch学习笔记(六):自定义Datasets]

 

你可能感兴趣的:(Pytorch,pytorch,datasets)