PyTorch Study Notes (2): Reading Data

Environment

  • OS: macOS Mojave
  • Python version: 3.7
  • PyTorch version: 1.4.0
  • IDE: PyCharm

Contents

  • 0. Preface
  • 1. Building the Dataset
    • 1.1 Option 1: Define a Dataset subclass
    • 1.2 Option 2: Use the ImageFolder class
  • 2. Building the DataLoader
    • 2.1 Optionally define a sampler to handle class imbalance
    • 2.2 Construct the DataLoader instance


0. Preface

This post records how to read image data with PyTorch. Once the data is laid out in a specific directory structure, you build a Dataset and a DataLoader to read it.

The Dataset defines where the data lives (data_dir) and how to read it (__getitem__), while the DataLoader supplies the indices that decide which examples the Dataset reads.
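This division of labor can be illustrated with a framework-free sketch (plain Python, no torch; `ToyDataset` and `toy_loader` are made-up names for illustration): the dataset only knows how to fetch one example by index, while the loader decides which indices to request and groups them into batches.

```python
class ToyDataset:
    """Pretends to read samples; each item is a fake (data, label) pair."""
    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        return index * 10, index % 2  # fake (image, label)

    def __len__(self):
        return self.n


def toy_loader(dataset, batch_size):
    """Yields batches by handing indices to the dataset, like a bare-bones DataLoader."""
    indices = list(range(len(dataset)))  # a sampler would reorder these
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


batches = list(toy_loader(ToyDataset(5), batch_size=2))
print(batches)
# [[(0, 0), (10, 1)], [(20, 0), (30, 1)], [(40, 0)]]
```

The real DataLoader follows the same contract, adding shuffling, multi-process loading, and tensor collation on top.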

1. Building the Dataset

Taking the already-split training set of the TinyMind RMB face value recognition task as an example, the directory structure is as follows:

├── rmbdata
│   ├── categorise.py
│   ├── split.py
│   ├── test
│   │   ├── images.jpg
│   ├── train
│   │   ├── images.jpg
│   ├── train_face_value_label.csv
│   ├── val
│   │   ├── images.jpg

1.1 Option 1: Define a Dataset subclass

Define a subclass of torch.utils.data.Dataset in a data module, override its __getitem__ and __len__ methods, and import it for training or evaluation:

├── data.py
└── rmbdata
import os
from PIL import Image
from torch.utils.data import Dataset

class RMBFaceValueDataset(Dataset):
    """Dataset for classifying RMB by the face value"""
    def __init__(self, data_dir, transform=None, class_to_idx=None):
        """
        Params:
        data_dir: str
            directory of data
        transform: torchvision.transform
            data transform approaches
        idx_to_class: dict
            map class name to a specified number
        """
        self.data_dir = data_dir
        self.transform = transform
        self.class_to_idx = class_to_idx
        self.data_info = self._get_img_info()

    def __getitem__(self, index):
        """
        Receive an index and return an example with its label

        Returns:
        image: torch.Tensor or PIL.Image
            image data
        label: int
            label for this example
        """
        image_path, label = self.data_info[index]
        image = Image.open(image_path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        return len(self.data_info)

    def _get_img_info(self):
        data_info = []
        for root, dirs, _ in os.walk(self.data_dir):
            for sub_dir in dirs:  # iterate over all class sub-directories
                image_names = os.listdir(os.path.join(root, sub_dir))
                image_names = list(filter(lambda x: x.endswith('.jpg'), image_names))

                for image_name in image_names:  # iterate over all images in this class
                    image_path = os.path.join(root, sub_dir, image_name)
                    label = self.class_to_idx[sub_dir]
                    data_info.append((image_path, int(label)))

        return data_info


if __name__ == '__main__':
    # In 'train.py' you would write:
    # from data import RMBFaceValueDataset
    from torch.utils.data import DataLoader
    from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor

    data_transforms = {
        'train': Compose([Resize(256), CenterCrop(224), ToTensor()]),
        'val': Compose([Resize(256), CenterCrop(224), ToTensor()]),
        'test': Compose([Resize(256), CenterCrop(224), ToTensor()])
    }

    data_dir = os.path.join(os.curdir, 'rmbdata')
    class_to_idx = {
        '0.1': 0, '0.2': 1, '0.5': 2,
        '1.0': 3, '2.0': 4, '5.0': 5,
        '10.0': 6, '50.0': 7, '100.0': 8
    }
    image_datasets = {x: RMBFaceValueDataset(
        os.path.join(data_dir, x), transform=data_transforms[x], class_to_idx=class_to_idx)
        for x in ['train', 'val', 'test']
    }

    # get the sizes of datasets for further calculation of training indices
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}

    # check
    print(dataset_sizes)
    # {'train': 39227, 'val': 393, 'test': 20000}

1.2 Option 2: Use the ImageFolder class

This way there is no need for a separate data module; directly in train.py:

import os
from torchvision.datasets import ImageFolder

# get the datasets, passing in the transforms
data_dir = os.path.join(os.curdir, 'rmbdata')
image_datasets = {x: ImageFolder(os.path.join(data_dir, x), transform=data_transforms[x])
                  for x in ['train', 'val', 'test']}

# get the dataset sizes for computing metrics during training
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}

print(dataset_sizes)
# {'train': 39227, 'val': 393, 'test': 20000}
print(image_datasets['train'].class_to_idx)
# {'0.1': 0, '0.2': 1, '0.5': 2, '1.0': 3, '10.0': 4, '100.0': 5, '2.0': 6, '5.0': 7, '50.0': 8}
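Note that ImageFolder assigns indices by sorting the class-directory names lexicographically as strings, which is why '10.0' gets index 4 here while the hand-written class_to_idx in Option 1 gave it 6. A quick stdlib-only check of that ordering:

```python
# ImageFolder sorts class names as strings, so '10.0' < '2.0' ('.' < '0'):
classes = ['0.1', '0.2', '0.5', '1.0', '2.0', '5.0', '10.0', '50.0', '100.0']
class_to_idx = {c: i for i, c in enumerate(sorted(classes))}
print(class_to_idx)
# {'0.1': 0, '0.2': 1, '0.5': 2, '1.0': 3, '10.0': 4, '100.0': 5, '2.0': 6, '5.0': 7, '50.0': 8}
```

Keep this in mind when comparing labels across the two options: the same class name maps to different indices unless you pass the same mapping everywhere.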

2. Building the DataLoader

During training or evaluation, after obtaining the Dataset, use the torch.utils.data.DataLoader class to get a DataLoader.

2.1 Optionally define a sampler to handle class imbalance

See: Balanced Sampling between classes with torchvision DataLoader

def make_weights_for_unbalanced_classes(image_path_label, n_classes):
    """
    Compute a weight for each sample, used later to define a sampler
    that counteracts class imbalance.

    Params:
    image_path_label: list of (str, int) tuples
        a list whose elements are (path, label) pairs, one per image
    n_classes: int
        total number of classes

    Returns:
    weights: list of floats
        the weight of each sample
    """
    count = [0] * n_classes  # per-class sample counts
    for _, label in image_path_label:  # count the samples in each class
        count[label] += 1

    weight_per_class = [0.] * n_classes  # per-class weights
    N = float(sum(count))  # total number of samples
    for i in range(n_classes):
        weight_per_class[i] = N / float(count[i])  # weight for class i

    weights = [0] * len(image_path_label)  # per-sample weights
    for i, (_, label) in enumerate(image_path_label):
        weights[i] = weight_per_class[label]

    return weights


# train_ds is assumed to be an ImageFolder instance: .imgs holds the
# (path, label) pairs and .classes the class names
import torch
from torch.utils.data import WeightedRandomSampler

weights = make_weights_for_unbalanced_classes(train_ds.imgs, len(train_ds.classes))
weights = torch.DoubleTensor(weights)
sampler = WeightedRandomSampler(weights, len(weights))
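A quick sanity check on the weighting scheme (pure Python, with hypothetical class counts): since each sample of class c gets weight N / count[c], every class contributes the same total weight to the sampler, so WeightedRandomSampler ends up drawing the classes roughly uniformly even when the raw counts are heavily skewed.

```python
# hypothetical imbalanced class counts
count = {0: 900, 1: 90, 2: 10}
N = sum(count.values())
weight_per_class = {c: N / n for c, n in count.items()}

# total weight each class contributes to the sampler: count[c] * (N / count[c]) = N
class_mass = {c: weight_per_class[c] * n for c, n in count.items()}
print(class_mass)
# {0: 1000.0, 1: 1000.0, 2: 1000.0}
```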

2.2 Construct the DataLoader instance

from torch.utils.data import DataLoader

# the sampler defined above can be passed in here
dataloaders = {x: DataLoader(
    image_datasets[x],  # the Dataset (subclass or ImageFolder) instance
    batch_size=128,
    num_workers=4,  # number of worker processes for loading (e.g. 4/8/16); speeds up reading and hence training
    sampler=None,
    shuffle=True,  # reshuffle the data every epoch; must not be True when a sampler is given
    drop_last=False  # whether to drop the last incomplete batch when the sample count is not divisible by batch_size
) for x in ['train', 'val', 'test']}
) for x in ['train', 'val', 'test']}

# check
print(len(dataloaders['train']), len(dataloaders['val']), len(dataloaders['test']))
# 307 4 157
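These lengths follow directly from len(DataLoader) being ceil(dataset_size / batch_size) when drop_last=False; a quick check against the sizes printed earlier:

```python
import math

# number of batches per split with batch_size=128 and drop_last=False
dataset_sizes = {'train': 39227, 'val': 393, 'test': 20000}
n_batches = {k: math.ceil(v / 128) for k, v in dataset_sizes.items()}
print(n_batches)
# {'train': 307, 'val': 4, 'test': 157}
```

With drop_last=True it would be floor() instead, discarding the trailing partial batch.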

# Get a batch of training data
inputs, labels = next(iter(dataloaders['train']))
print(inputs.size(), len(labels))
# torch.Size([128, 3, 224, 224]) 128
