torch.utils.data.Dataset

Contents

  • torch.utils.data.Dataset
    • Structure
    • Examples
      • Super-resolution dataset bsd_300
        • __getitem__()
        • transform
      • imagenet22k dataset
        • __getitem__()
    • Related Links

torch.utils.data.Dataset

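An abstract class representing a dataset. All map-style datasets should subclass it and override __getitem__(), which fetches the data sample for a given key. Subclasses may also override __len__(), which should return the size of the dataset. Many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader rely on it. Subclasses can also optionally implement __getitems__() to speed up loading batched samples: it accepts the list of sample indices of a batch and returns the list of samples. The built-in datasets in torchvision, as well as its dataset base classes DatasetFolder, ImageFolder and VisionDataset, are all subclasses of it, so they can also serve as references when writing your own datasets.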
For details, see
vision/torchvision/datasets at main · pytorch/vision (github.com)
Datasets — Torchvision 0.16 documentation (pytorch.org)
and my earlier posts
The three base classes of torchvision.datasets
torchvision.datasets
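As a minimal sketch of the interface described above, here is a toy map-style dataset. The class `SquaresDataset` and its data are hypothetical. It overrides `__getitem__()` and `__len__()` and also defines the optional `__getitems__()`. Recent PyTorch versions have DataLoader call `__getitems__()` once per batch when it is present. Older versions fall back to calling `__getitem__()` per index, so the batches come out the same either way.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy map-style dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Map-style access: an integer key returns exactly one sample
        if not 0 <= index < self.n:
            raise IndexError(index)
        return torch.tensor(index), torch.tensor(index ** 2)

    def __len__(self):
        # Used by the default sampler to know which indices exist
        return self.n

    def __getitems__(self, indices):
        # Optional batched access: one call per batch, returning a list of samples.
        # This toy version just loops; a real one might read a whole batch in one I/O call.
        return [self[i] for i in indices]


loader = DataLoader(SquaresDataset(5), batch_size=2, shuffle=False)
batches = [(x.tolist(), y.tolist()) for x, y in loader]
print(batches)  # [([0, 1], [0, 1]), ([2, 3], [4, 9]), ([4], [16])]
```

The default collate function stacks the per-sample tensors, so each batch is a pair of 1-D tensors. The last batch is shorter because `drop_last` defaults to `False`.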

Structure

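# Abridged from torch/utils/data/dataset.py; the imports and T_co are added so the excerpt runs on its own
from typing import Generic, TypeVar

from torch.utils.data import ConcatDataset  # defined later in the same file in the real source

T_co = TypeVar("T_co", covariant=True)


class Dataset(Generic[T_co]):

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # No default __len__ is defined here; subclasses are expected to provide it.

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        return ConcatDataset([self, other])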
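Because the base class defines `__add__`, two datasets can be concatenated with `+`. The following small sketch uses a hypothetical `ListDataset` class. It shows that the result is a `ConcatDataset`, which maps a global index onto the correct sub-dataset.

```python
from torch.utils.data import ConcatDataset, Dataset


class ListDataset(Dataset):
    """Hypothetical map-style dataset backed by a Python list."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


# Dataset.__add__ wraps both operands in a ConcatDataset
combined = ListDataset("abc") + ListDataset("xy")

print(type(combined).__name__, len(combined))        # ConcatDataset 5
print(combined.cumulative_sizes)                     # [3, 5]
print([combined[i] for i in range(len(combined))])   # ['a', 'b', 'c', 'x', 'y']
```

`ConcatDataset` needs `len()` of each part to build `cumulative_sizes`, which is one practical reason to implement `__len__()` in map-style subclasses.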

Examples

Super-resolution dataset bsd_300

Reference: examples/super_resolution at main · pytorch/examples (github.com)

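At initialization, the files in the dataset directory are checked to see which ones are images, giving a list of image file paths. Each image is then read as a PIL Image object and converted to YCbCr mode, and only its Y (luminance) channel is kept.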
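The Y-channel extraction can be checked in isolation with PIL alone. This sketch uses a synthetic solid-color image rather than a BSDS300 file.

```python
from PIL import Image

# Synthetic 4x3 solid red RGB image standing in for a real photo
rgb = Image.new("RGB", (4, 3), color=(255, 0, 0))

# Same steps as load_img(): convert to YCbCr, then split into the Y, Cb, Cr bands
ycbcr = rgb.convert("YCbCr")
y, cb, cr = ycbcr.split()

print(ycbcr.mode, y.mode, y.size)  # YCbCr L (4, 3)
```

Each band returned by `split()` is a single-channel mode `"L"` image with the original width and height. The model therefore receives one channel instead of three.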

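__getitem__()
  • Input: an index; the image file at that position in the list is loaded as a PIL Image (Y channel only)
  • Output
    • input: the image after the chosen input transform
    • target: the ground-truth label; here it is also derived from the input image, but with a different transform than input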
import torch.utils.data as data

from os import listdir
from os.path import join
from PIL import Image


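def is_image_file(filename):
    # Keep only files with these (lowercase) image extensions
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])


def load_img(filepath):
    # Convert to YCbCr and keep only the Y (luminance) channel
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y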


class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        input = load_img(self.image_filenames[index])
        # The target is a copy of the same Y-channel image; only the transforms differ
        target = input.copy()
        if self.input_transform:
            input = self.input_transform(input)
        if self.target_transform:
            target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)
transform

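As mentioned above, in this task both the input and the label are transforms of the original image, so what distinguishes them is the pair of transforms, input_transform and target_transform. The target is simply the center-cropped original image. The image fed into the network is additionally resized down by the upscale factor (to crop_size // upscale_factor), which discards image detail that the super-resolution network is then trained to recover.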

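from os.path import join

from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor

# DatasetFromFolder is the class defined above; download_bsd300() is defined elsewhere in the
# example's data.py and returns the directory of the downloaded and extracted BSDS300 images.


def calculate_valid_crop_size(crop_size, upscale_factor):
    # Round crop_size down to a multiple of upscale_factor so the downscaled size is exact
    return crop_size - (crop_size % upscale_factor)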

def input_transform(crop_size, upscale_factor):
    return Compose([
        CenterCrop(crop_size),
        Resize(crop_size // upscale_factor),
        ToTensor(),
    ])

def target_transform(crop_size):
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),
    ])

def get_training_set(upscale_factor):
    root_dir = download_bsd300()
    train_dir = join(root_dir, "train")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))

def get_test_set(upscale_factor):
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))

imagenet22k dataset

Reference: Swin-Transformer/data/imagenet22k_dataset.py at main · microsoft/Swin-Transformer (github.com)

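First, the list of image files is obtained: the file names and their class indices are stored in the *_map.txt annotation files. Each image is then read as a PIL Image object and converted to RGB mode.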

Directory tree of the dataset:

$ tree imagenet22k/
imagenet22k/
├── ILSVRC2011fall_whole_map_train.txt
├── ILSVRC2011fall_whole_map_val.txt
└── fall11_whole
    ├── n00004475
    ├── n00005787
    ├── n00006024
    ├── n00006484
    └── …
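The IN22KDATASET code further down reads the annotation file with `json.load` in `__init__()` and indexes each entry as `idb[0]` (path) and `int(idb[1])` (label). So despite the `.txt` extension, the file is expected to hold a JSON list of `[relative path, class index]` entries. The following sketch writes and reads such a file in a temporary directory. The image file names inside the class folders are hypothetical.

```python
import json
import os
import tempfile

# Hypothetical annotation entries: [path relative to the dataset root, class index]
entries = [["n00004475/img_0.JPEG", 0], ["n00005787/img_1.JPEG", 1]]

with tempfile.TemporaryDirectory() as root:
    ann_path = os.path.join(root, "ILSVRC2011fall_whole_map_val.txt")
    with open(ann_path, "w") as f:
        json.dump(entries, f)

    # Mirrors IN22KDATASET.__init__: parse the whole annotation file with json.load
    with open(ann_path) as f:
        database = json.load(f)

# Mirrors what __getitem__(1) does with one entry
idb = database[1]
rel_path, target = idb[0], int(idb[1])
print(len(database), rel_path, target)  # 2 n00005787/img_1.JPEG 1
```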

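__getitem__()
  • Input: an index into the entries read from the annotation file; the image of that entry is loaded and converted to RGB
  • Output
    • images: the RGB image after the transform
    • target: here, the class index, read from the given annotation file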
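
# data/build.py in Swin-Transformer; build_transform and CachedImageFolder are defined
# elsewhere in that repo, and IN22KDATASET is the class shown below
import os

from torchvision import datasets


def build_dataset(is_train, config):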
    transform = build_transform(is_train, config)
    if config.DATA.DATASET == 'imagenet':
        prefix = 'train' if is_train else 'val'
        if config.DATA.ZIP_MODE:
            ann_file = prefix + "_map.txt"
            prefix = prefix + ".zip@/"
            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
        else:
            root = os.path.join(config.DATA.DATA_PATH, prefix)
            dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif config.DATA.DATASET == 'imagenet22K':
        prefix = 'ILSVRC2011fall_whole'
        if is_train:
            ann_file = prefix + "_map_train.txt"
        else:
            ann_file = prefix + "_map_val.txt"
        dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform)
        nb_classes = 21841
    else:
        raise NotImplementedError("We only support ImageNet Now.")

    return dataset, nb_classes

# data/imagenet22k_dataset.py in Swin-Transformer
import os
import json
import torch.utils.data as data
import numpy as np
from PIL import Image

import warnings

warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)


class IN22KDATASET(data.Dataset):
    def __init__(self, root, ann_file='', transform=None, target_transform=None):
        super(IN22KDATASET, self).__init__()

        self.data_path = root
        self.ann_path = os.path.join(self.data_path, ann_file)
        self.transform = transform
        self.target_transform = target_transform
        # id & label: https://github.com/google-research/big_transfer/issues/7
        # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027
        # The annotation file is parsed as JSON, even though it has a .txt extension
        with open(self.ann_path) as f:
            self.database = json.load(f)

    def _load_image(self, path):
        try:
            im = Image.open(path)
        except Exception:
            # Unreadable file: log it and fall back to a random 224x224 RGB noise image
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
    
        idb = self.database[index]

        # images
        images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB')
        if self.transform is not None:
            images = self.transform(images)

        # target
        target = int(idb[1])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return images, target

    def __len__(self):
        return len(self.database)

Related Links

microsoft/Swin-Transformer: This is an official implementation for “Swin Transformer: Hierarchical Vision Transformer using Shifted Windows”. (github.com)

pytorch/torch/utils/data/dataset.py at main · pytorch/pytorch (github.com)
