[笔记] Data Augmentation for Computer Vision with PyTorch

数据增强(Data Augmentation)是一种避免模型过拟合、提高模型泛化能力的常见手段,由于种类和细节繁多,很少有人对其进行整理。
本篇博文旨在对常见的数据增强方式进行整理,并利用 torchvision 所提供的函数进行 Python 代码示例。

目录

  1. Basic knowledge About torchvision.transforms
  2. Data Augmentation for Computer Vision

1. Basic knowledge About torchvision.transforms

torchvision 提供子包 transforms,用以进行常见的图像变换(image transformation),常用于图像数据的数据增强。
官方文档:链接

torchvision.transforms 所提供的各个具体的 transforms 类可以利用 torchvision.transforms.Compose() 进行结合,方便代码的编写:

trans = transforms.Compose([ 
            transforms.Resize((args.MultiScale,args.MultiScale)),
            transforms.ColorJitter(brightness=(0.5,1.5), contrast=(0.5,1.5), saturation=(0.5,1.5), hue=(-0.1,0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
            ])

常见的 transforms 类所处理的对象都是 PIL 的 Image 实例,不过 torchvision.transforms 仍然提供了可以处理其他数据的 transforms:

  • Transforms on torch.*Tensor
# 利用该 transfroms 类对 Tensor 进行归一化
trans = torchvision.transforms.Normalize(mean, std, inplace=False)
  • Conversion Transforms
# 利用该 transfroms 类将 Tensor 转化为 PIL 的 Image 实例
trans = torchvision.transforms.ToPILImage(mode=None)
# 利用该 transfroms 类将 PIL 的 Image 实例转化为 Tensor
trans = torchvision.transforms.ToTensor()
  • Generic Transforms
# 该 transfroms 类支持将 lambda function 转为 transform,常与 transforms.Compose() 结合使用
trans = torchvision.transforms.Lambda(lambda_function)

2. Data Augmentation for Computer Vision

  • 光度畸变(Photometric Distortions)

    • 随机色彩抖动(Random Color Jitter)
    # 其中,brightness 指定亮度抖动范围,contrast 指定对比度抖动范围,saturation 指定饱和度抖动范围,hue 指定色调抖动范围。
    # 参数格式为:float or tuple of python:float (min, max),分别对应的抖动范围为[max(0,1-param),1+param] or the given [min, max]。
    # brightness/contrast/saturation 要求非负, 而 hue 则要求 0<=hue<=0.5 or -0.5<=min<=max<=0.5$。
    trans = torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
    
    # 还可以用 torchvision.transforms.functional 所提供的函数进行色彩抖动
    trans = torchvision.transforms.functional.adjust_brightness(img, brightness_factor) # 调节图像亮度
    trans = torchvision.transforms.functional.adjust_contrast(img, contrast_factor) # 调节图像对比度
    trans = torchvision.transforms.functional.adjust_saturation(img, saturation_factor) # 调节图像饱和度
    trans = torchvision.transforms.functional.adjust_hue(img, hue_factor) # 调节图像色调
    trans = torchvision.transforms.functional.adjust_gamma(img, gamma, gain=1) # 进行伽马矫正
    
    • 随机通道(Random Channel Shffule)
    ChannelShuufle = [Lambda(lambda img: Image.merge('RGB', (img.split[0], img.split[1], img.split[2]))),
    Lambda(lambda img: Image.merge('RGB', (img.split[0], img.split[2], img.split[1]))),
    Lambda(lambda img: Image.merge('RGB', (img.split[1], img.split[0], img.split[1]))),
    Lambda(lambda img: Image.merge('RGB', (img.split[1], img.split[2], img.split[0]))),
    Lambda(lambda img: Image.merge('RGB', (img.split[2], img.split[0], img.split[1]))),
    Lambda(lambda img: Image.merge('RGB', (img.split[2], img.split[1], img.split[0])))]
    trans = torchvision.transforms.RandomChoice(ChannelShffule)
    
    • 随机转化为灰度图(Random Grayscale)
    # 其中,p 为随机转化概率,所输出 Tensor 的通道数与输入保持一致
    trans = torchvision.transforms.RandomGrayscale(p=0.1)
    # 利用 RandomApply 可以做到指定所输出 Tensor 的通道数
    # 其中,p 为随机概率,num_output_channels 指定所输出 Tensor 的通道数
    trans = torchvision.transforms.RandomApply(torchvision.transforms.Grayscale(num_output_channels=1), p=0.5)
    
    • 归一化(Normalization)
    # 其中,Iamge Net 常用参数为:mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]
    # PS:需要先将图像转为 Tensor
    trans = torchvision.transforms.Compose([ 
            transforms.ToTensor(),
            transforms.Normalize(mean, std, inplace=False),
            ])
    
    • 未完待续
  • 几何畸变(Geometric Distortions)

    • 调整图像大小(Resize Image)
    # 其中,size 为缩放后图像的大小,格式为:(sequence or int); interpolation 指定插值方式,为可选参数,格式为:int(默认值为 PIL.Image.BILINEAR)。
    # size 为 (h,w) 时,缩放后的图像大小为 (h,w);size 为 int 时,图像将保持原有宽高比进行缩放且最短边将缩放至 size 大小,即 (size * height/width, size)(当原图像高比宽长时)。
    trans = torchvision.transforms.Resize(size, interpolation=2)
    
    • 扩展边界框(Expand Bounding Box)
    def EnlargeBBox(img,bbox,Factor_H,Factor_V):
    	ori_img_w, ori_img_h = img.size
    
        # BoundingBox
        left  = bbox[0]
        upper = bbox[1]
        right = bbox[2]
        lower = bbox[3]
        
        # Set Enlarge Facotr
        padding_h = max(0, int((right - left)*Factor_H)) 
        padding_v = max(0, int((lower - upper)*Factor_V))
        
        # Enlarge Bounding Box
    	left  = max(left - int(0.5*padding_h), 0)
        right = min(right + int(0.5*padding_h), ori_img_w)
        upper = max(upper - int(0.5*padding_v), 0)
        lower = min(lower + int(0.5*padding_v), ori_img_h)
        
    	# return (left,right,upper,lower)
    	return left, upper, lower-upper, right-left
    
    # 利用 torchvision.transforms.functional.crop 进行 ROI 提取
    patch = torchvision.transforms.functional.crop(img, *EnlargeBBox(img,bbox,Factor_H,Factor_V))
    
    • 中心裁剪(Center Crop)
    # 其中,size 指定所裁剪区域大小,参数格式为:int or tuple of python
    # 当 size 为 int 时,所裁剪的区域大小为 (size,size)
    trans = torchvision.transforms.CenterCrop(size)
    
    # 可以用 torchvision.transforms.functional.crop(img, i, j, h, w) 进行精细裁剪
    # 其中,(i,j) 为裁剪区域的左上角坐标,(h,w) 为裁剪区域的高与宽
    patch = torchvision.transforms.functional.crop(img, i, j, h, w)
    
    • 随机裁剪(Random Crop)
    trans = torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
    
    • 随机旋转(Random Rotation)
    # 其中,degrees 指定随机旋转的角度范围,格式为:(sequence or float or int) 。
    trans = torchvision.transforms.RandomRotation(degrees, resample=False, expand=False, center=None)
    
    • 随机翻转(Random Flip)

      • 随机垂直翻转
      trans = torchvision.transforms.RandomVerticalFlip(p=0.5) # 其中,p 为翻转概率
      
      • 随机水平翻转
      trans = torchvision.transforms.RandomHorizontalFlip(p=0.5) # 其中,p 为翻转概率
      
    • 未完待续

参考资料:

  • Image augmentation for machine learning experiments.
  • Torchvision transforms 总结
  • 深度学习入门之Pytorch——数据增强

如果你看到了这篇文章的最后,并且觉得有帮助的话,麻烦你花几秒钟时间点个赞,或者受累在评论中指出我的错误。谢谢!

作者信息:
知乎:没头脑
LeetCode:Tao Pu
CSDN:Code_Mart
Github:Bojack-want-drink

你可能感兴趣的:(PyTorch,Work,工作成长)