表示一个数据集的抽象类,Map-style的数据集都应该是它的子类,并且重写__getitem__(),支持给定key值获取数据;重写__len__()返回数据集的尺寸,以配合torch.utils.data.Sampler和torch.utils.data.DataLoader的默认选项。子类也可以实现__getitems__()来加速批次样本加载,这个方法接受批次样本的索引列表,返回样本的列表。torchvision中的内置数据集和数据集基类DatasetFolder, ImageFolder和VisionDataset都是它的子类,因此也可以作为自制数据集的参考。
具体可以参照
vision/torchvision/datasets at main · pytorch/vision (github.com)
Datasets — Torchvision 0.16 documentation (pytorch.org)
和之前的文章
torchvision.datasets的三个基础类
torchvision.datasets
'''SOURCE'''
class Dataset(Generic[T_co]):
    """Abstract base class for map-style datasets.

    Subclasses must override ``__getitem__`` to fetch a sample for a given
    key; here it raises by default so forgetting to override fails loudly.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        # ``ds_a + ds_b`` concatenates two datasets into one ConcatDataset.
        return ConcatDataset([self, other])
参照examples/super_resolution at main · pytorch/examples (github.com)
首先初始化判别数据集路径中的图像文件,得到图像文件路径的列表,之后图像被读取为PIL的Image对象,转换为YCbCr模式后,只选择其中的Y分量
import torch.utils.data as data
from os import listdir
from os.path import join
from PIL import Image
def is_image_file(filename):
    """Return True if *filename* has a supported image extension.

    The comparison is case-insensitive, so files such as ``photo.PNG`` or
    ``photo.JPG`` are accepted too (the original per-extension ``endswith``
    loop was case-sensitive and silently skipped them). ``str.endswith``
    accepts a tuple, so a single call replaces the ``any(...)`` scan.
    """
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))
def load_img(filepath):
    """Load an image file and return only its luminance (Y) channel.

    The image is converted to YCbCr and the Cb/Cr chroma channels are
    discarded — super-resolution here is trained on luminance only.

    Using ``with`` closes the underlying file handle deterministically:
    ``Image.open`` is lazy and the original code left the file open until
    garbage collection, which can exhaust file descriptors on large datasets.
    """
    with Image.open(filepath) as img:
        # convert() forces the pixel data to load before the file closes.
        y, _, _ = img.convert('YCbCr').split()
    return y
class DatasetFromFolder(data.Dataset):
    """Map-style dataset over all image files directly inside *image_dir*.

    Each sample is an ``(input, target)`` pair derived from the same source
    image: the target is the (optionally transformed) original, while the
    input typically gets an extra downscaling transform for super-resolution
    training.

    Args:
        image_dir: directory scanned (non-recursively) for image files.
        input_transform: optional callable applied to the input image.
        target_transform: optional callable applied to the target image.
    """

    def __init__(self, image_dir, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        # Keep only entries whose extension marks them as images.
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # Renamed from ``input`` to avoid shadowing the Python builtin.
        input_image = load_img(self.image_filenames[index])
        # Target starts as a copy so the two transforms never alias pixels.
        target = input_image.copy()
        if self.input_transform:
            input_image = self.input_transform(input_image)
        if self.target_transform:
            target = self.target_transform(target)
        return input_image, target

    def __len__(self):
        return len(self.image_filenames)
如上面提到的,在这个任务中,输入和标签都是原输入图像的变换,因此区分两者的是input和target两种不同的变换:target即为只经过裁剪的原图像,而输入网络的图片要按定义的上采样因子Resize至尺寸更小的图片,以引入图像信息损失,来训练超分辨率网络。
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Largest size not exceeding *crop_size* that is an exact multiple
    of *upscale_factor*, so the crop downscales to integer dimensions."""
    return (crop_size // upscale_factor) * upscale_factor
def input_transform(crop_size, upscale_factor):
    """Build the transform for network inputs: center-crop to *crop_size*,
    shrink by *upscale_factor* (the lossy step the network must undo),
    then convert to a tensor."""
    low_res_size = crop_size // upscale_factor
    steps = [
        CenterCrop(crop_size),
        Resize(low_res_size),
        ToTensor(),
    ]
    return Compose(steps)
def target_transform(crop_size):
    """Build the transform for targets: the original image, only
    center-cropped to *crop_size* and converted to a tensor."""
    steps = [
        CenterCrop(crop_size),
        ToTensor(),
    ]
    return Compose(steps)
def get_training_set(upscale_factor):
    """Download BSD300 (if needed) and build the training dataset with
    input/target transforms sized for *upscale_factor*."""
    train_dir = join(download_bsd300(), "train")
    # Snap 256 down to a multiple of the upscale factor.
    crop_size = calculate_valid_crop_size(256, upscale_factor)
    return DatasetFromFolder(
        train_dir,
        input_transform=input_transform(crop_size, upscale_factor),
        target_transform=target_transform(crop_size),
    )
def get_test_set(upscale_factor):
    """Download BSD300 (if needed) and build the test dataset with
    input/target transforms sized for *upscale_factor*."""
    test_dir = join(download_bsd300(), "test")
    # Snap 256 down to a multiple of the upscale factor.
    crop_size = calculate_valid_crop_size(256, upscale_factor)
    return DatasetFromFolder(
        test_dir,
        input_transform=input_transform(crop_size, upscale_factor),
        target_transform=target_transform(crop_size),
    )
参照Swin-Transformer/data/imagenet22k_dataset.py at main · microsoft/Swin-Transformer (github.com)
首先得到图像文件路径的列表,文件名和类别序号存储在_map.txt文件中,之后图像被读取为PIL的Image对象并转换为RGB模式。
数据集的树形结构
$ tree imagenet22k/
imagenet22k/
├── ILSVRC2011fall_whole_map_train.txt
├── ILSVRC2011fall_whole_map_val.txt
└── fall11_whole
├── n00004475
├── n00005787
├── n00006024
├── n00006484
└── …
输入:txt标注文件(每条记录为图像相对路径和类别序号)和读取的RGB图像
输出:经transform变换后的图像和整数类别标签
def build_dataset(is_train, config):
    """Construct the dataset for the requested split and return it with
    its class count.

    Supports two values of ``config.DATA.DATASET``: ``'imagenet'`` (1000
    classes; either a plain ImageFolder layout or a zip-cached layout when
    ``ZIP_MODE`` is set) and ``'imagenet22K'`` (21841 classes; annotation
    txt files read by IN22KDATASET). Anything else raises.
    """
    transform = build_transform(is_train, config)
    dataset_name = config.DATA.DATASET

    if dataset_name == 'imagenet':
        split = 'train' if is_train else 'val'
        if config.DATA.ZIP_MODE:
            ann_file = split + "_map.txt"
            # Validation always uses partial caching regardless of config.
            cache = config.DATA.CACHE_MODE if is_train else 'part'
            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file,
                                        split + ".zip@/", transform,
                                        cache_mode=cache)
        else:
            split_root = os.path.join(config.DATA.DATA_PATH, split)
            dataset = datasets.ImageFolder(split_root, transform=transform)
        nb_classes = 1000
    elif dataset_name == 'imagenet22K':
        suffix = "_map_train.txt" if is_train else "_map_val.txt"
        dataset = IN22KDATASET(config.DATA.DATA_PATH,
                               'ILSVRC2011fall_whole' + suffix,
                               transform)
        nb_classes = 21841
    else:
        raise NotImplementedError("We only support ImageNet Now.")

    return dataset, nb_classes
import os
import json
import torch.utils.data as data
import numpy as np
from PIL import Image
import warnings
# Many ImageNet-22K JPEGs carry malformed EXIF blocks; PIL emits a
# UserWarning for each one, flooding the log, so suppress exactly that
# warning message pattern and nothing else.
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
class IN22KDATASET(data.Dataset):
    """ImageNet-22K dataset driven by a JSON annotation file.

    The annotation file is a JSON list whose entries are pairs of
    ``[relative_image_path, class_index]``; images are loaded with PIL,
    converted to RGB, and labels are returned as ints.

    Args:
        root: dataset root; image paths in the annotation file are
            resolved relative to it.
        ann_file: annotation filename inside *root*.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the int label.
    """

    def __init__(self, root, ann_file='', transform=None, target_transform=None):
        super(IN22KDATASET, self).__init__()
        self.data_path = root
        self.ann_path = os.path.join(self.data_path, ann_file)
        self.transform = transform
        self.target_transform = target_transform
        # id & label: https://github.com/google-research/big_transfer/issues/7
        # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027
        # Context manager closes the annotation file; the original
        # ``json.load(open(...))`` leaked the handle.
        with open(self.ann_path) as f:
            self.database = json.load(f)

    def _load_image(self, path):
        """Open *path* with PIL; on failure, log the path and return a random
        224x224 RGB image so one corrupt file cannot abort a training epoch."""
        try:
            im = Image.open(path)
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # SystemExit still propagate; the fallback stays best-effort.
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        idb = self.database[index]
        # images
        images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB')
        if self.transform is not None:
            images = self.transform(images)
        # target
        target = int(idb[1])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return images, target

    def __len__(self):
        return len(self.database)
microsoft/Swin-Transformer: This is an official implementation for “Swin Transformer: Hierarchical Vision Transformer using Shifted Windows”. (github.com)
pytorch/torch/utils/data/dataset.py at main · pytorch/pytorch (github.com)