Rewriting the YOLO data-loading module

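Motivation

        YOLO's original data loader is hard to read and hard to extend. The other reason is that I wanted to get familiar with the source code.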

Define a class:

class DataFolder(VisionDataset):

Docstring of the base class, VisionDataset:

    """
    Base Class For making datasets which are compatible with torchvision.
    It is necessary to override the ``__getitem__`` and ``__len__`` method.

    Args:
        root (string): Root directory of dataset.
        transforms (callable, optional): A function/transforms that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note::

        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
    """

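For details, see my other post, where I read through the source of this class. I think of this class as essentially defining a coding convention.

Implementation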

        

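import os
from pathlib import Path
from typing import Any, Callable, Optional, Tuple

import cv2
import numpy as np
from torchvision.datasets import VisionDataset

# find_img_labels is the author's own helper (not shown in this post); from its
# use below, it returns the image paths and the matching label paths under root.


class DataFolder(VisionDataset):
    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            cache=False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.im_files, self.label_files = find_img_labels(self.root)
        # each sample: [image path, label path, .npy cache path, image cached in RAM or None]
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in zip(self.im_files, self.label_files)]
        self.index = range(len(self.label_files))
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'

    def __len__(self) -> int:
        # VisionDataset requires __len__ to be overridden
        return len(self.samples)

    def _load_image(self, id: int):
        f, t, fn, im = self.samples[id]
        if self.cache_ram:
            if im is None:  # first access: read the image and keep it in RAM
                im = cv2.imread(f)  # BGR
                self.samples[id][3] = im
        elif self.cache_disk:
            if not fn.exists():  # first access: write the .npy cache
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)  # load npy
        else:  # no cache: read the image every time
            im = cv2.imread(f)  # BGR
        return im

    def _load_target(self, id: int):
        f, t, fn, im = self.samples[id]
        if not os.path.isfile(t):
            raise FileNotFoundError(f'{t} does not exist')
        with open(t) as fh:  # separate name so the image path f is not shadowed
            lb = [x.split() for x in fh.read().strip().splitlines() if len(x)]
        return np.array(lb, dtype=np.float32)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        id = self.index[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target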

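Overall idea: collect the image paths and the label paths, assign the indices, then define two loaders, one for images and one for labels. That is all there is to it.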
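The post does not show find_img_labels, so here is a minimal sketch of what such a helper could look like. Everything in it is an assumption for illustration: the directory layout (<root>/images/... paired with <root>/labels/... and a .txt suffix), the set of image extensions, and the name IMG_SUFFIXES. Only the function name and the two returned lists come from the class above.

```python
from pathlib import Path
from typing import List, Tuple

# Assumed set of image extensions; extend as needed.
IMG_SUFFIXES = {'.bmp', '.jpg', '.jpeg', '.png'}


def find_img_labels(root: str) -> Tuple[List[str], List[str]]:
    """Return parallel lists of image paths and label paths under ``root``.

    Assumed (hypothetical) layout: <root>/images/**/x.jpg pairs with
    <root>/labels/**/x.txt.
    """
    img_dir = Path(root) / 'images'
    label_dir = Path(root) / 'labels'
    im_files, label_files = [], []
    for p in sorted(img_dir.rglob('*')):
        if p.suffix.lower() not in IMG_SUFFIXES:
            continue  # skips directories and non-image files
        # mirror the path relative to images/ under labels/, with a .txt suffix
        label = (label_dir / p.relative_to(img_dir)).with_suffix('.txt')
        im_files.append(p.as_posix())
        label_files.append(label.as_posix())
    return im_files, label_files
```

The helper does not check that each label file exists; in the class above that check happens later, in _load_target, which raises FileNotFoundError.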

Result:

[Figure 1: result of the rewritten YOLO data-loading module]

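mosaic:

Add one line to the initializer:

self.mosaic = True

Add this to __getitem__ (self.img_size is not defined in the class above, so it has to be set in the initializer as well):

        if self.mosaic:
            image, target = self.load_mosaic(id, self.img_size)

load_mosaic: copying it from the YOLO source is enough; only the image loading and label loading calls need to change.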
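To show the shape of what gets copied, here is a simplified sketch of a four-image mosaic written as a standalone function. It is not the YOLO source: the name mosaic4 and its parameters are made up for illustration, xywhn2xyxy is defined locally, and it assumes 3-channel images whose sides are at most s plus labels in normalized [cls, xc, yc, w, h] form. The follow-up augmentations that YOLO applies after the mosaic are left out.

```python
import random

import numpy as np


def xywhn2xyxy(x, w, h, padw=0, padh=0):
    """Normalized [xc, yc, w, h] boxes -> pixel [x1, y1, x2, y2], shifted by (padw, padh)."""
    y = np.empty_like(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # left
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # right
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom
    return y


def mosaic4(images, labels, s, center=None):
    """Paste four images around a common centre on a (2s, 2s) canvas.

    images: four HxWx3 uint8 arrays, each side <= s (assumption).
    labels: four (n, 5) float arrays of normalized [cls, xc, yc, w, h].
    Returns the canvas and (N, 5) labels [cls, x1, y1, x2, y2] in canvas pixels.
    """
    if center is None:  # random mosaic centre inside the middle of the canvas
        xc = int(random.uniform(s / 2, 3 * s / 2))
        yc = int(random.uniform(s / 2, 3 * s / 2))
    else:
        xc, yc = center
    canvas = np.full((2 * s, 2 * s, 3), 114, dtype=np.uint8)  # grey background
    merged = []
    for i, (img, lb) in enumerate(zip(images, labels)):
        h, w = img.shape[:2]
        # (x1a, y1a, x2a, y2a): target region on the canvas
        # (x1b, y1b, x2b, y2b): source region cropped from the image
        if i == 0:  # top left
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, 2 * s), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(2 * s, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        else:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, 2 * s), min(2 * s, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
        canvas[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]
        if len(lb):
            # shift the boxes by the offset between canvas region and image crop
            boxes = xywhn2xyxy(lb[:, 1:5], w, h, padw=x1a - x1b, padh=y1a - y1b)
            merged.append(np.concatenate((lb[:, :1], boxes), axis=1))
    if merged:
        out = np.concatenate(merged, axis=0)
    else:
        out = np.zeros((0, 5), dtype=np.float32)
    np.clip(out[:, 1:], 0, 2 * s, out=out[:, 1:])  # keep boxes inside the canvas
    return canvas, out
```

Inside the class, a load_mosaic method would pick three extra random indices, call self._load_image and self._load_target for all four, and pass the results to a function like this one. Note that the returned labels are pixel corner coordinates, not the normalized format stored in the label files.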

mixup:

Likewise, copying the source code over is enough.
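For reference, mixup itself is only a few lines. The sketch below assumes the two images already have the same shape and that both label arrays use the same column layout; the default alpha of 32.0 follows what YOLOv5 uses as far as I know, so treat that value as an assumption rather than a quote from the source.

```python
import numpy as np


def mixup(im, labels, im2, labels2, alpha=32.0):
    """Blend two same-shaped images and keep the boxes of both.

    The blend ratio r is drawn from Beta(alpha, alpha); alpha=32.0 is an
    assumed default that concentrates r around 0.5.
    """
    r = np.random.beta(alpha, alpha)
    mixed = (im * r + im2 * (1 - r)).astype(np.uint8)
    # labels are not blended: the mixed image simply carries both sets of boxes
    return mixed, np.concatenate((labels, labels2), axis=0)
```

In the dataset class this would run after the first sample (or mosaic) has been built, with the second image and its labels loaded through the same _load_image and _load_target calls.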

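Summary

The other augmentation methods work much the same way. It mainly comes down to index handling and to loading the labels and images. Performance is almost certainly much worse than native YOLO's: native YOLO processes all labels and images in one pass at initialization and then caches them, which is why its code base is so large and so hard to read. In addition, the torchvision.datasets package provides a coding convention with very clear logic. For beginners, I think following it is a good choice.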
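If the per-item label parsing ever becomes the bottleneck, one cheap step in native YOLO's direction is to parse every label file once up front. This is only a sketch: parse_label_file and preload_labels are hypothetical names, and each label row is assumed to have five columns (class, xc, yc, w, h).

```python
from pathlib import Path
from typing import List

import numpy as np


def parse_label_file(path: str) -> np.ndarray:
    """Parse one label file into an (n, 5) float32 array; an empty file gives (0, 5)."""
    rows = [ln.split() for ln in Path(path).read_text().splitlines() if ln.strip()]
    return np.array(rows, dtype=np.float32).reshape(-1, 5)


def preload_labels(label_files: List[str]) -> List[np.ndarray]:
    """Read all label files once, e.g. from __init__, so __getitem__ only indexes a list."""
    return [parse_label_file(p) for p in label_files]
```

The trade-off is the one described above: initialization gets slower and the code grows, in exchange for less work per item.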
