torchvision.datasets中为自定义数据集提供了三个基础类:DatasetFolder、ImageFolder和VisionDataset。这三者除了均为torch.utils.data.Dataset的子类外,它们之间也存在继承关系。其中VisionDataset定义于datasets/vision.py,DatasetFolder和ImageFolder定义于datasets/folder.py。VisionDataset没有提供可用的__getitem__ 和__len__实现(默认实现只会抛出NotImplementedError),DatasetFolder继承自VisionDataset,重写了__getitem__ 和__len__方法,ImageFolder又继承自DatasetFolder。
vision/torchvision/datasets at main · pytorch/vision (github.com)
Datasets — Torchvision 0.16 documentation (pytorch.org)
torchvision.datasets.CIFAR10继承自VisionDataset,重写了__getitem__ 和__len__方法,并且定义了_load_meta方法以实现类似find_classes方法的功能。
# 部分代码 具体代码参照https://pytorch.org/vision/0.16/_modules/torchvision/datasets/cifar.html#CIFAR10
class CIFAR10(VisionDataset):
    """CIFAR10 dataset built on ``VisionDataset`` (partial excerpt).

    Only ``_load_meta`` and ``__getitem__`` are reproduced here. The
    attributes they read (``root``, ``base_folder``, ``meta``, ``data``,
    ``targets``, ``transform``, ``target_transform``) are expected to be set
    by parts of the class that are not shown.
    """

    def _load_meta(self) -> None:
        """Read the class names from the dataset's metadata file.

        Sets ``self.classes`` to the value stored under ``self.meta["key"]``
        and ``self.class_to_idx`` to a name -> index mapping in the same
        order.

        Raises:
            RuntimeError: If ``check_integrity`` rejects the metadata file
                (checked against ``self.meta["md5"]``).
        """
        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        if not check_integrity(path, self.meta["md5"]):
            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
        with open(path, "rb") as infile:
            # SECURITY: unpickling can execute arbitrary code. This line is
            # only reached after the integrity check above accepted the file.
            data = pickle.load(infile, encoding="latin1")
        self.classes = data[self.meta["key"]]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair stored at ``index``.

        Args:
            index: Position in ``self.data`` / ``self.targets``.

        Returns:
            The image (after ``self.transform`` when one is set) and the
            target (after ``self.target_transform`` when one is set).
        """
        img, target = self.data[index], self.targets[index]

        # Wrap the raw array in an image object before any transform runs
        # (presumably a PIL Image -- confirm against the module's imports).
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
torchvision.datasets.VOCSegmentation继承了_VOCBase,_VOCBase又继承自VisionDataset
class VOCSegmentation(_VOCBase):
    """Segmentation variant of the VOC dataset, built on ``_VOCBase``.

    The three class attributes below are configuration read by ``_VOCBase``
    (not shown here); ``images``, ``targets`` and ``transforms`` are likewise
    expected to be provided by the base class.
    """

    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"

    @property
    def masks(self) -> List[str]:
        """Alias for ``self.targets``, used as mask paths in ``__getitem__``."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        # The image is forced to RGB; the mask is opened as stored.
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])

        # ``transforms`` (plural) receives image and mask together.
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target
继承自torch.utils.data.Dataset() 依然需要重写__getitem__ 和__len__方法
参数
root 数据集的根地址,仅用于重写__repr__
transforms 应用在一张图像和标签的变换,并且返回两者的变换版本
transform 应用在图像上的变换,返回变换后的版本
target_transform 应用在标签上的变换,返回变换后的版本
import os
import torch.utils.data as data
class VisionDataset(data.Dataset):
    """Base class for vision datasets.

    ``__getitem__`` and ``__len__`` raise ``NotImplementedError`` here, so
    subclasses must override both.

    Args:
        root: Root directory of the dataset. A ``str`` value is passed
            through ``os.path.expanduser``; other values are stored as given.
        transforms: Callable that takes the sample and its target and
            returns transformed versions of both.
        transform: Callable applied to the sample only.
        target_transform: Callable applied to the target only.

    Raises:
        ValueError: If ``transforms`` is given together with ``transform``
            or ``target_transform``.
    """

    # Number of spaces used to indent each body line of ``__repr__``.
    _repr_indent = 4

    def __init__(
        self,
        root: str = None,  # type: ignore[assignment]
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        _log_api_usage_once(self)
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        # Fold the separate callables into one joint ``transforms`` object so
        # that ``self.transforms`` is usable whichever style was passed.
        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:
        """Return the sample at ``index``; must be overridden."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of samples; must be overridden."""
        raise NotImplementedError

    def __repr__(self) -> str:
        """Return a multi-line summary: class name, size, root, extras, transforms."""
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()
        # ``hasattr`` guards instances whose ``__init__`` never ran.
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix a transform's repr with ``head``, aligning continuation lines.

        The first repr line is appended to ``head``; each further line is
        padded with ``len(head)`` spaces so it lines up under the first.
        """
        lines = repr(transform).splitlines()
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def extra_repr(self) -> str:
        """Return extra text for ``__repr__``; empty unless overridden."""
        return ""
在这个文件中有DatasetFolder类,以及DatasetFolder类默认调用的函数find_classes()和make_dataset()
DatasetFolder中find_classes方法默认调用的函数,找到一个如下结构存储的数据集中的类别目录
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── …
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── …
└── asd932_.ext
参数:
返回:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Find the class folders directly under ``directory``.

    Every immediate sub-directory is treated as one class; plain files at the
    top level are ignored.

    Args:
        directory: Root directory whose sub-directories name the classes.

    Returns:
        ``(classes, class_to_idx)``: the sorted class names and a mapping
        from each name to its position in that sorted list.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directory.
    """
    # The context manager closes the directory handle even when ``is_dir()``
    # raises part-way through the scan.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
参数
返回
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Collect ``(path, class_index)`` samples from a class-per-folder tree.

    Args:
        directory: Root directory; ``~`` is expanded.
        class_to_idx: Mapping of class name to index. When ``None`` it is
            derived with ``find_classes(directory)``.
        extensions: Allowed file extension(s). Mutually exclusive with
            ``is_valid_file``.
        is_valid_file: Predicate that receives a file path and says whether
            to keep it. Mutually exclusive with ``extensions``.

    Returns:
        A list of ``(path, class_index)`` tuples, ordered by class name, then
        by walked directory, then by file name.

    Raises:
        ValueError: If ``class_to_idx`` is empty, or if ``extensions`` and
            ``is_valid_file`` are both ``None`` or both given.
        FileNotFoundError: If any class in ``class_to_idx`` ends up with no
            valid file.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        # Exactly one filter was supplied (checked above), so build the
        # predicate from the extension list.
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        # A class without a folder is skipped here and reported below.
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
检查一个文件是否是允许的拓展名
参数
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Check whether ``filename`` ends with one of the allowed extensions.

    Only the file name is lower-cased before comparison, so the extensions
    themselves must be given in lower case.

    Args:
        filename: Path or name of the file.
        extensions: A single extension string, or an iterable of them.

    Returns:
        ``True`` if the lower-cased ``filename`` ends with an extension.
    """
    # ``str.endswith`` accepts a str or a tuple, so other iterables are
    # converted to a tuple first.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    return filename.lower().endswith(suffixes)
检查一个文件是否是允许的图片扩展名
参数:
# Lower-case file extensions accepted as images.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def is_image_file(filename: str) -> bool:
    """Check whether ``filename`` ends with one of ``IMG_EXTENSIONS``.

    Args:
        filename: Path or name of the file.

    Returns:
        ``True`` if the file name has an allowed image extension.
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)
DatasetFolder位于torchvision.datasets.folder,继承自VisionDataset,是一个通用的数据loader;目录结构可以通过重写find_classes方法自定义
参数
def __init__(
    self,
    root: str,
    loader: Callable[[str], Any],
    extensions: Optional[Tuple[str, ...]] = None,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None:
    """Scan ``root`` and record every sample with its class index.

    Args:
        root: Root directory of the dataset.
        loader: Callable that loads one sample from its path.
        extensions: Allowed file extensions, forwarded to ``make_dataset``.
        transform: Callable applied to each loaded sample.
        target_transform: Callable applied to each class index.
        is_valid_file: Path predicate, forwarded to ``make_dataset``.
    """
    super().__init__(root, transform=transform, target_transform=target_transform)
    # Go through the instance so that an overridden ``find_classes`` or
    # ``make_dataset`` is the one that runs.
    classes, class_to_idx = self.find_classes(self.root)
    samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

    self.loader = loader
    self.extensions = extensions

    self.classes = classes
    self.class_to_idx = class_to_idx
    self.samples = samples
    # Class index of every sample, in the same order as ``samples``.
    self.targets = [s[1] for s in samples]
参数
directory 数据集的根目录,对应self.root
class_to_idx 类名和类序号的映射字典
extensions 允许的扩展名列表
is_valid_file 验证文件有效性的函数
class_to_idx参数不能为None,因为make_dataset()需要使用类内的find_classes方法,如果为None则class_to_idx会默认使用类外的find_classes函数,而类内方法是类外函数的重写,因此两者可能因为overridden而产生逻辑的不同
@staticmethod
def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Build the ``(path, class_index)`` sample list for this dataset.

    Declared as a static method: the signature takes no ``self``, while the
    constructor invokes it as ``self.make_dataset(self.root, ...)``.

    Args:
        directory: Root directory of the dataset.
        class_to_idx: Mapping of class name to class index; required here.
        extensions: Allowed file extensions.
        is_valid_file: Predicate that receives a file path and says whether
            to keep it.

    Returns:
        Whatever the module-level ``make_dataset`` returns for the same
        arguments.

    Raises:
        ValueError: If ``class_to_idx`` is ``None``.
    """
    if class_to_idx is None:
        # The module-level function would fall back to the module-level
        # ``find_classes``, silently bypassing an overridden
        # ``find_classes`` method, so ``None`` is rejected here.
        raise ValueError("The class_to_idx parameter cannot be None.")
    return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
默认返回torchvision.datasets.folder中的find_classes函数,通过重写来对应不同数据集结构
参数
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Return the classes found under ``directory``.

    Delegates to the module-level ``find_classes``; override this method to
    support a different directory layout.

    Args:
        directory: Root directory of the dataset.

    Returns:
        ``(classes, class_to_idx)`` as produced by the module-level
        ``find_classes``.
    """
    return find_classes(directory)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """Load and return the sample at ``index``.

    Args:
        index: Position in ``self.samples``.

    Returns:
        ``(sample, target)``: the object produced by ``self.loader`` (after
        ``self.transform`` when one is set) and the class index (after
        ``self.target_transform`` when one is set).
    """
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
        sample = self.transform(sample)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return sample, target
def __len__(self) -> int:
    """Return the number of entries in ``self.samples``."""
    return len(self.samples)
继承自 DatasetFolder 一个通用的数据loader 数据集默认按以下结构排列:
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[…]/asd932_.png
参数:
class ImageFolder(DatasetFolder):
    """``DatasetFolder`` specialised for image files.

    Args:
        root: Root directory of the dataset.
        transform: Callable applied to each loaded image.
        target_transform: Callable applied to each class index.
        loader: Callable that loads an image from its path.
        is_valid_file: Predicate that receives a file path and says whether
            to keep it. When given, the extension filter is not used.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            # The extension filter and ``is_valid_file`` are mutually
            # exclusive, so the default extensions apply only when no
            # predicate was supplied.
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # Alias of ``samples`` under an image-specific name.
        self.imgs = self.samples
folder.py提供了三种loader,分别是 default_loader、accimage_loader和pil_loader。后两者分别使用了两种读取图像的库accimage和PIL;default_loader则根据get_image_backend()的返回值在两者之间选择。对于PIL的简单介绍,也可以参照PIL Image 模块
def pil_loader(path: str) -> Image.Image:
    """Load the image at ``path`` and return it converted to RGB.

    Args:
        path: Path of the image file.

    Returns:
        The result of ``Image.open(...).convert("RGB")``.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        # Convert while the file is still open: ``f`` is closed as soon as
        # the ``with`` block exits.
        return img.convert("RGB")
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load the image at ``path`` with accimage, falling back to PIL.

    Args:
        path: Path of the image file.

    Returns:
        An ``accimage.Image``, or the result of ``pil_loader(path)`` when
        accimage raises ``OSError``.
    """
    # Imported here so that accimage is only needed when this loader runs.
    import accimage

    try:
        return accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
def default_loader(path: str) -> Any:
    """Load the image at ``path`` with the configured torchvision backend.

    Args:
        path: Path of the image file.

    Returns:
        The result of ``accimage_loader(path)`` when ``get_image_backend()``
        returns ``"accimage"``, otherwise the result of ``pil_loader(path)``.
    """
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path)