Implementing your own ImageFolder in PyTorch (loading image pairs in classification tasks)

A PyTorch ImageFolder that loads images together with their corresponding masks

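Purpose

In some classification tasks, each training sample consists of an original image plus a corresponding mask, for example when the original image and its semantic segmentation result must be loaded together. PyTorch's built-in data-loading functions cannot meet this need. The approach below adds the capability without major changes to the existing code.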

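Implementation

We start from the torchvision.datasets.ImageFolder source file that ships with PyTorch (torchvision/datasets/folder.py) and modify a few places so that it loads image pairs.

  1. First, copy the code of the original ImageFolder source file into a new .py file;
  2. Modify the DatasetFolder class as shown below;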
```python
class DatasetFolder(VisionDataset):
    """A generic data loader where the image/mask pairs are arranged in this way: ::

        root/image/class_x/xxx.ext
        root/image/class_x/[...]/xxz.ext
        root/image/class_y/123.ext

        root/mask/class_x/xxx.ext
        root/mask/class_x/[...]/xxz.ext
        root/mask/class_y/123.ext

    Images and masks are paired by their position in the sorted file lists, so
    both trees must contain the same number of files in the same sort order.

    Args:
        root (string): Root directory path, containing ``image`` and ``mask``.
        loader (callable): A function to load an image given its path.
        mask_loader (callable): A function to load a mask given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            an image and its mask and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        image_samples (list): List of (image path, class_index) tuples
        mask_samples (list): List of (mask path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            mask_loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        # Images live under root/image and masks under root/mask.
        self.mask_root = os.path.join(self.root, 'mask')
        self.image_root = os.path.join(self.root, 'image')
        # Class folders are discovered from the image tree only.
        classes, class_to_idx = self._find_classes(self.image_root)

        image_samples = self.make_dataset(self.image_root, class_to_idx, extensions, is_valid_file)
        mask_samples = self.make_dataset(self.mask_root, class_to_idx, extensions, is_valid_file)

        if len(image_samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.image_root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)
        # Pairs are matched by index, so a count mismatch would silently misalign them.
        if len(image_samples) != len(mask_samples):
            raise RuntimeError("Found {} images but {} masks under: {}".format(
                len(image_samples), len(mask_samples), self.root))

        self.loader = loader
        self.mask_loader = mask_loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.image_samples = image_samples
        self.mask_samples = mask_samples
        self.targets = [s[1] for s in image_samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where sample is the output of
            transform(image, mask), or the (image, mask) pair when no transform
            is set, and target is class_index of the target class.
        """
        image_path, target = self.image_samples[index]
        # The class index is taken from the image entry; only the mask path is needed here.
        mask_path, _ = self.mask_samples[index]
        image_sample = self.loader(image_path)
        mask_sample = self.mask_loader(mask_path)
        if self.transform is not None:
            # The transform must accept (image, mask); see step 5.
            sample = self.transform(image_sample, mask_sample)
        else:
            # Without a transform, return the raw pair so `sample` is always defined.
            sample = (image_sample, mask_sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.image_samples)
```
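
Because DatasetFolder pairs images and masks purely by their position in two sorted lists, a missing or misnamed file can shift every later pair. The following is a minimal sketch of a pre-flight check that matches files by relative path instead. The helper names `index_by_stem` and `match_pairs` are hypothetical and not part of torchvision, and the sketch assumes that an image and its mask share the same relative path apart from the file extension.

```python
import os
from typing import Dict, List, Tuple


def index_by_stem(directory: str) -> Dict[str, str]:
    """Map each file's relative path (without extension) to its full path."""
    index = {}
    for dirpath, _, filenames in os.walk(directory):
        for name in filenames:
            full = os.path.join(dirpath, name)
            stem = os.path.splitext(os.path.relpath(full, directory))[0]
            index[stem] = full
    return index


def match_pairs(root: str) -> List[Tuple[str, str]]:
    """Return (image_path, mask_path) pairs, or raise if any file is unpaired."""
    images = index_by_stem(os.path.join(root, "image"))
    masks = index_by_stem(os.path.join(root, "mask"))
    # Stems that appear in only one of the two trees have no partner.
    unpaired = sorted(set(images) ^ set(masks))
    if unpaired:
        raise RuntimeError("Unpaired files (relative stems): {}".format(unpaired))
    return [(images[stem], masks[stem]) for stem in sorted(images)]
```

Running `match_pairs(root)` once before training is one way to confirm that the two folder trees line up; it does not replace the index-based pairing used by the dataset class itself.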

  3. Modify the ImageFolder class as follows:
```python
class ImageFolder(DatasetFolder):
    """A generic data loader where the images and masks are arranged in this way: ::

        root/image/dog/xxx.png
        root/image/dog/[...]/xxz.png
        root/image/cat/123.png

        root/mask/dog/xxx.png
        root/mask/dog/[...]/xxz.png
        root/mask/cat/123.png

    Args:
        root (string): Root directory path, containing ``image`` and ``mask``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and its mask and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        mask_loader (callable, optional): A function to load a mask given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            mask_loader: Callable[[str], Any] = png_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super(ImageFolder, self).__init__(root, loader, mask_loader, IMG_EXTENSIONS if is_valid_file is None else None,
                                          transform=transform,
                                          target_transform=target_transform,
                                          is_valid_file=is_valid_file)
        # Keep the `imgs` alias of the original ImageFolder, pointing at the image list.
        self.imgs = self.image_samples
```
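  4. Add an extra png_loader function, as follows. Define it above the ImageFolder class, because ImageFolder uses it as a default argument value: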
```python
def png_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        # Return the mask in palette mode ('P'), one index per pixel.
        return img.convert('P')
```
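
To illustrate what png_loader returns, the sketch below writes a tiny palette-mode PNG and reads it back. It assumes the masks are already stored as palette PNGs whose pixel values are class indices; the helper `make_demo_mask` and its three-color palette are made up for the demonstration. For a mask stored in another mode such as RGB, `convert('P')` maps colors onto a palette, so the resulting indices may not correspond to class ids.

```python
import os
import tempfile

from PIL import Image


def png_loader(path: str) -> Image.Image:
    # Same loader as in step 4: open the file and return it in palette mode.
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('P')


def make_demo_mask(path: str) -> None:
    """Write a 2x2 palette PNG whose pixel values are the indices 0, 1 and 2."""
    mask = Image.new('P', (2, 2))
    # Hypothetical palette: index 0 = black, 1 = red, 2 = green, rest black.
    mask.putpalette([0, 0, 0, 255, 0, 0, 0, 255, 0] + [0] * (768 - 9))
    mask.putpixel((0, 0), 0)
    mask.putpixel((1, 0), 1)
    mask.putpixel((0, 1), 2)
    mask.putpixel((1, 1), 1)
    mask.save(path)


with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, 'mask.png')
    make_demo_mask(demo_path)
    loaded = png_loader(demo_path)
    # In mode 'P' each pixel is one byte holding its palette index (row-major order).
    indices = list(loaded.tobytes())
```

Here `indices` holds the per-pixel palette indices in row-major order, which is the form a downstream transform would typically convert into a label tensor.
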
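  5. Modify the transforms so that they process the image and the mask in sync. For the detailed implementation, see:
    "Implementing transforms that process the image and mask in sync"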
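
The referenced implementation is not reproduced here. As a rough illustration of the idea only, the sketch below shows one way to build transforms that match the `self.transform(image_sample, mask_sample)` call in `__getitem__`. The class names `PairedCompose` and `PairedRandomApply` are hypothetical, and the sketch is deliberately library-agnostic: `op` is any single-argument function, for example a PIL horizontal flip.

```python
import random
from typing import Any, Callable, Optional, Sequence, Tuple


class PairedCompose:
    """Chain transforms that each take (image, mask) and return (image, mask)."""

    def __init__(self, transforms: Sequence[Callable[[Any, Any], Tuple[Any, Any]]]) -> None:
        self.transforms = list(transforms)

    def __call__(self, image: Any, mask: Any) -> Tuple[Any, Any]:
        for transform in self.transforms:
            image, mask = transform(image, mask)
        return image, mask


class PairedRandomApply:
    """Apply `op` to both image and mask, or to neither, with probability p."""

    def __init__(self, op: Callable[[Any], Any], p: float = 0.5,
                 rng: Optional[random.Random] = None) -> None:
        self.op = op
        self.p = p
        self.rng = rng if rng is not None else random.Random()

    def __call__(self, image: Any, mask: Any) -> Tuple[Any, Any]:
        # A single random draw decides for both inputs, which keeps them in sync.
        if self.rng.random() < self.p:
            return self.op(image), self.op(mask)
        return image, mask
```

With PIL images, `op` could be something like `lambda im: im.transpose(Image.FLIP_LEFT_RIGHT)`. The key design point is that the random decision is made once per pair; applying two independent random transforms to the image and the mask would break their pixel correspondence.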

Note: the code above was tested with PyTorch 1.8.0.
