pytorch一致数据增强—异用增强

前作 [1] 介绍了一种用 pytorch 模仿 MONAI 实现多幅图(如:image 与 label)同用 random seed 保证一致变换的写法,核心是 MultiCompose 类和 to_multi 包装函数。不过 [1] 没考虑不同图用不同 augmentation 的情况,如:

  1. ColorJitter 只对 image 做,而不对 label 做;
  2. image 的 resize interpolation 可任选,但 label 只能用 nearest

本篇更新写法,支持各图同用、异用 augmentation。

Code

  • 对比 [1],主要改变是改写 MultiCompose 类,并将 to_multi 吸收入内。
  • MultiCompose 的用法还是和 torchvision.transforms.Compose 几乎一致,不过支持异用 augmentation:只要为各图指定各自的 augmentation 类/函数即可。见下一节例程。
def to_multi():
	"""不用单独的 to_multi 打包了,已并入 MultiCompose"""
	raise NotImplementedError


class MultiCompose:
    """Extension of torchvision.transforms.Compose that accepts multiple inputs
    and ensures the same random seed is applied on each of these inputs at each transforms.
    This can be useful when simultaneously transforming images & segmentation masks.
    """

    # numpy.random.seed range error:
    #   ValueError: Seed must be between 0 and 2**32 - 1
    MIN_SEED = 0 # - 0x8000_0000_0000_0000
    MAX_SEED = min(2**32 - 1, 0xffff_ffff_ffff_ffff)

    def __init__(self, transforms):
        # self.transforms = [to_multi(t) for t in transforms]
        no_op = lambda x: x # i.e. identity function
        self.transforms = []
        for t in transforms:
            if isinstance(t, (tuple, list)):
            	# convert `None` to `no_op` for convenience
                self.transforms.append([no_op if _t is None else _t for _t in t])
            else:
                self.transforms.append(t)

    def __call__(self, *images):
        for t in self.transforms:
            if isinstance(t, (tuple, list)):
                assert len(images) <= len(t) # allow redundant transform
            else:
                t = [t] * len(images)

            _aug_images = []
            _seed = random.randint(self.MIN_SEED, self.MAX_SEED)
            for _im, _t in zip(images, t):
                seed_everything(_seed)
                _aug_images.append(_t(_im))

            images = _aug_images

        if len(images) == 1:
            images = images[0]
        return images

Usage & Test

例程沿用 [1],但改一下 augmentation:

train_trans = MultiCompose([
	# image 用 bilinear,label 用 nearest
    (ResizeZoomPad((224, 256), "bilinear"), ResizeZoomPad((224, 256), "nearest")), # 异用
    transforms.RandomAffine(30, (0.1, 0.1)), # 同用,传一个就行
    transforms.RandomHorizontalFlip(), # 同用
    # ColorJitter 只对 image 做,label 不做(None)
    [transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None], # 异用
])
  • 效果:

pytorch一致数据增强—异用增强_第1张图片

References

  1. pytorch一致数据增强

你可能感兴趣的:(机器学习,pytorch,python,torchvision,数据增强,random)