目录
继承关系
初始化方法:
一:find_classes
二:make_dataset
三: 写一个验证函数
四:loader
五:
六: __getitem__:
总结:
class ImageFolder(DatasetFolder):
    """Image-specific ``DatasetFolder`` with an image loader and extension list as defaults.

    Args:
        root: Dataset root directory, forwarded to ``DatasetFolder``.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each target.
        loader: Callable that turns a path into a sample; defaults to
            ``default_loader``.
        is_valid_file: Optional predicate on a file path. When it is given,
            no extension list is forwarded to the parent class.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Fall back to the built-in image extensions only when the caller did
        # not supply a custom validity predicate.
        if is_valid_file is None:
            allowed_extensions = IMG_EXTENSIONS
        else:
            allowed_extensions = None

        super().__init__(
            root,
            loader,
            allowed_extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # ``imgs`` is just another name for the parent's ``samples`` list.
        self.imgs = self.samples
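ImageFolder 本身几乎啥事没干:只是把默认的图片后缀(IMG_EXTENSIONS)和默认的 loader 传给父类 DatasetFolder,再给 samples 起了个别名 imgs。真正的逻辑都在父类里: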
class DatasetFolder(VisionDataset):
    """Generic dataset whose samples are files grouped into one sub-folder per class.

    After construction the instance exposes:

    * ``classes``: sorted list of class (folder) names,
    * ``class_to_idx``: mapping from class name to integer index,
    * ``samples``: list of ``(path, class_index)`` tuples,
    * ``targets``: the ``class_index`` of every entry in ``samples``.

    Args:
        root: Root directory of the dataset.
        loader: Callable that loads a sample from its path.
        extensions: Allowed file extensions, or ``None``.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each target.
        is_valid_file: Optional predicate deciding whether a path is a sample.
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Discover the classes first; the sample scan needs the name -> index map.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        Args:
            directory: Root dataset directory.
            class_to_idx: Mapping from class name to class index.
            extensions: Allowed file extensions, or ``None``.
            is_valid_file: Optional predicate on a file path.

        Returns:
            List of ``(path_to_sample, class_index)`` tuples.

        Raises:
            ValueError: If ``class_to_idx`` is ``None``.
        """
        if class_to_idx is None:
            # The module-level helper would otherwise fall back to its own
            # class discovery, bypassing an overridden ``find_classes`` method.
            raise ValueError("The class_to_idx parameter cannot be None.")
        # Delegates to the module-level ``make_dataset`` function of
        # torchvision.datasets.folder (not shown in this excerpt).
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders under ``directory``.

        Thin wrapper around the module-level ``find_classes`` function so
        subclasses can override how classes are discovered.

        Returns:
            ``(classes, class_to_idx)`` as produced by ``find_classes``.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        # Samples are loaded lazily, one file per access.
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of samples found under the root directory."""
        return len(self.samples)
classes, class_to_idx = self.find_classes(self.root)
具体实现主要是:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Discover the class folders directly under ``directory``.

    Every immediate sub-directory counts as one class; plain files are ignored.

    Args:
        directory: Root directory to scan.

    Returns:
        A tuple ``(classes, class_to_idx)``: the sorted sub-directory names and
        a dict mapping each name to its position in that sorted list.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directory.
    """
    folder_names = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    # Sort so that the index assigned to each class is deterministic.
    folder_names.sort()

    if len(folder_names) == 0:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    index_by_name = dict(zip(folder_names, range(len(folder_names))))
    return folder_names, index_by_name
打断点进去看一下:
总的来说,就是根据根路径,得到其下所有子文件夹的名字(排序后)以及对应的数字索引。也就是说,它把每个子文件夹名当作一个类别(注意是文件夹名,不是文件名)。
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
参数说明: root: 数据集根目录
class_to_idx:类别索引
extensions:图片后缀 ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
is_valid_file: 是一个可调用的函数: Optional[Callable[[str], bool]]
返回:一个列表,每个元素是 (图片路径, 类别索引) 的元组
if is_valid_file(path):
item = path, class_index
instances.append(item)
如果验证通过则会加到返回中,反之不会。
class Check:
    """Demo validator: prints the value it is constructed with and accepts everything."""

    def __init__(self, key):
        # Echo the received value so it can be inspected while debugging.
        print('看看值', key, sep='\n')

    def __call__(self, *args, **kwargs):
        # Unconditionally report "valid", whatever arguments are given.
        return True
使用:直接传进去
class TestDataset(torchvision.datasets.ImageFolder):
    """``ImageFolder`` subclass demonstrating a custom ``is_valid_file`` hook.

    Args:
        root: Dataset root directory, forwarded to ``ImageFolder``.
        imgsz: Accepted but not used by this constructor.
        cache: Accepted but not used by this constructor.
        augment: Accepted but not used by this constructor.
    """

    def __init__(self, root, imgsz, cache, augment):
        # The hook is invoked as ``is_valid_file(path)``. Passing the ``Check``
        # class itself would merely construct an (always truthy) instance per
        # path and never run ``Check.__call__``, so route the path through the
        # instance and let its return value decide whether the file is kept.
        super().__init__(root=root, is_valid_file=lambda path: Check(path)(path))
结果:
传的是一个图片地址,可以拿到图片做一些验证工作。
self.loader = loader
也是一个回调函数
loader: Callable[[str], Any],
默认提供的是:
def default_loader(path: str) -> Any:
    """Load the image at ``path`` with the loader matching the active image backend.

    Args:
        path: Path of the image file.

    Returns:
        Whatever the selected loader returns: ``accimage_loader(path)`` when
        the backend is ``"accimage"``, otherwise ``pil_loader(path)``.
    """
    # Imported here rather than at module level, as in the original.
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == "accimage"
    load = accimage_loader if use_accimage else pil_loader
    return load(path)
def pil_loader(path: str) -> Image.Image:
    """Read the image at ``path`` with PIL and return it converted to RGB.

    Args:
        path: Path of the image file.

    Returns:
        The image converted to mode ``"RGB"``.
    """
    # Open the file ourselves instead of handing PIL the path, to avoid a
    # ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, "rb") as stream:
        # The conversion happens while the file is still open.
        rgb_image = Image.open(stream).convert("RGB")
    return rgb_image
一个读取图片的方法而已。
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
我们其实只关心 samples(图片路径和类别索引组成的列表)和 targets(标签)。
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
如果要重写,我觉得主要就是getitem方法。
首先:这个类主要给我们提供了收集文件路径的方法,我们可以直接拿到文件路径集合(samples)。
有了文件路径,就没必要再用它的文件加载方法(loader)。yolo 中有更高效的加载方法;如果是目标检测,我们还可以重新指定标签,比如 yolo 中规定标签文件和图片同名,便于查找。