YOLOv5添加自定义数据增广方法

YOLOv5添加自定义数据增广方法

虽然YOLOv5内置数据增广方法非常丰富,包括随机旋转、翻转、HSV-Saturation等。但仍然有添加自定义的数据增广方法的情况。例如使用N+L策略训练网络。N+L表示Normal resolution和Low resolution混合数据集。这时就需要在YOLOv5中添加退化算法。那么这里就以一种用于盲超分模型的退化算法[1]为例,以下称为B-DEGRADE,源代码来源于BSRGAN项目,展示一哈如何在YOLOv5项目中添加自定义数据增广方法。

YOLOv5的Dataloader与Dataset

YOLOv5 在/yolov5/utils/dataset.py中创建Dataloader与Dataset实例。like below。Dataloader使用直接使用pytorch Dataloader类或者基于Dataloader创建子类。Dataloader负责在训练和验证时产生batch迭代器。该类的一个重要成员变量就是Dataset。Dataset具有加载数据、按索引获取数据等功能。数据增广Dataset获取数据时完成!

def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        # 创建Dataset实例
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augmentation
                                      hyp=hyp,  # hyperparameters
                                      rect=rect,  # rectangular batches
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    # 使用pytorch自带的Dataloader或者使用基于Dotaloader创建的子类实例
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    return loader(dataset,
                  batch_size=batch_size,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=True,
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset

YOLOv5 Dataset是LoadImagesAndLabels类的实例。LoadImagesAndLabels类在调用构造函数时完成数据集加载,可以在__init__方法中看到加载图像位置与标签位置的代码,like below。cache_labels方法读取标签文件生成标签矩阵。在cache_labels方法中调用了一个需要特别注意的方法verify_image_label,这个方法用来验证某个图像是否具有标签,(Attetion: 特别注意这里,后面要用到哦!)如果图像没有标签,那么该图像索引对应的标签被赋值为np.zero((0,5), dtype=np.float32)

    def __init__(self,...):
        ...
        try:
            # 读取图像文件位置
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
        
        # 查看是否有标签数据cache文件,如果没有则读取使用cache_labels方法读取label文件
        # Check cache
        # 根据图像文件位置生成标签文件位置
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # same version
            assert cache['hash'] == get_hash(self.label_files + self.img_files)  # same hash
        except Exception:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache
def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix = args
    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []  # number (missing, found, empty, corrupt), message, segments
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                if f.read() != b'\xff\xd9':  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
                    msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any([len(x) > 8 for x in lb]):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = segments[i]
                    msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
            else:
                # 图像不存在label时,将label定义为如下形式
                ne = 1  # label empty
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            # label缺失时,将label定义为如下形式
            nm = 1  # label missing
            lb = np.zeros((0, 5), dtype=np.float32)
        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]

YOLOv5的数据增广方法

LoadImagesAndLabels类在按索引获取数据时完成数据增广,方法为__getitem__(self, index)。YOLOv5内置的数据增广方式包括Mosaic、随机旋转、随机翻转、HSV等。Mosaic默认打开。因此我们可以直接进入load_mosaic中添加自己自定义的数据增广方法。

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic: # mosaic默认打开,所以直接进入load_mosaic方法即可
            # Load mosaic
            img, labels = self.load_mosaic(index)
            shapes = None
        ...
    def load_mosaic(self, index):
        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        # 随机产生另外3个图片,与指定图片拼接在一起
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        # mosaic后,标签数值也发生相应变化
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        # 对拼接后的图像进行数据增广
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
        img4, labels4 = random_perspective(img4, labels4, segments4,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img4, labels4

将自己的增广方法添加到YOLOv5中

Now,我们尝试将BSRGAN中自带的退化方法B-DEGRADE为添加到YOLOv5中。目标是在将batch投入训练前,按照一定概率(设当 P < 0.3 P<0.3 P<0.3时)使图片退化。then,Datasetloader怎么获取一个batch的?根据上面的代码分析,当然是通过Dataset实例的__getitem__方法啦。那么我们只需要在getitem读取完图片后,马上退化图片即可儿。因为getitem通过调用self.load_mosaic加载图片,那么我们就在load_mosaic中添加自己的退化方法!所以we have

# 首先记得import我们的增广方法b_degrade
from BSRGAN import utils.utils_blindsr.degradation_bsrgan as b_degrade
...

    def __getitem__(...):
      ...
      img, labels = self.load_mosaic(index)
      ...
    
    def load_mosaic(...):
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        # 随机产生另外3个图片,与指定图片拼接在一起
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            ### Start: 添加我们的增广方法b_degrade ###

            degrade_flag = False
            if self.augment:
              if random.random() < 0.3:
                img, _ = b_degrade(img, sf=2) # 将退化算法中下采样率设置为2
                degrade_flag = True
            n, m, _ = img.shape
            
            labels = self.labels[index].copy() # 提前加载labels
            # 由于退化中含有下采样,应当剔除退化后无法识别的目标。
            if degrade_flag:
              labels = clear_cant_detect_object(labels, n, m, threshold=783) # 按照设定的阈值剔除无法识别的目标儿

            ### End: 添加我们的增广方法b_degrade ###

            ...

参考文献

[1]. Zhang, Kai, et al. “Designing a practical degradation model for deep blind image super-resolution.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.

你可能感兴趣的:(数据增广,YOLOv5,pytorch,深度学习,目标检测)