深度学习模型的鲁棒性(robustness)和泛化性受到训练数据的多样性和数据量所影响。数据增强(data augmentation)是机器学习和深度学习中经常采用的一个方法,其目的是扩大训练样本的数量。
语义分割是计算机视觉一个重要的下游任务,语义分割的数据增强通常需要对图像及其对应的标签做相同的增强处理
本文总结了3种常用的增强方式:(1)旋转,(2)翻转,(3)裁剪。所有操作均采用opencv库进行
首先使用opencv定义数据读取和保存函数。
# 定义数据读取函数
import cv2
import os
import numpy as np
___________________________________________________________
def read_data(file, mode=1):
"""
Args:
file: 数据路径
mode: bool值,若读取3通道则1,读取灰度图则为0
Returns:
"""
if mode == 1:
img = cv2.imread(file)
# print(img.shape)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
if mode == 0:
img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
return img
else:
raise ValueError("mode should be a bool number 1 or 0")
def save_data(save_pth, img):
img = img.astype(np.uint8)
cv2.imwrite(save_pth, img)
旋转操作(Rotate)
def rotate(img, gt, angle=10):
"""
angle: 旋转的角度
"""
img = read_data(img, 1)
gt = read_data(gt, 0)
assert img.shape[:2] == gt.shape[:2]
h, w = img.shape[:2]
center = (w / 2, h / 2)
mat = cv2.getRotationMatrix2D(center, angle, scale=1)
rotated_img = cv2.warpAffine(img, mat, (h, w))
rotated_gt = cv2.warpAffine(gt, mat, (h, w))
return rotated_img, rotated_gt
翻转操作(flip)
# 翻转图片(水平和垂直) 核心函数:cv.flip()
def flip(img, gt, direction=1):
"""
Args:
img:
direction: bool, 1表示水平翻转,0表示垂直翻转
Returns:
"""
# 如果传入的是图像路径:则先调用函数读取图像,再对图像进行翻转
if type(img) == str:
img = read_data(img, 1)
gt = read_data(gt, 0)
assert img.shape[:2] == gt.shape[:2]
# 如果传入的是读取的图像数据,则直接对图像进行翻转
assert img.shape[:2] == gt.shape[:2]
flipped_img = cv2.flip(img, direction)
flipped_gt = cv2.flip(gt, direction)
return flipped_img, flipped_gt
裁剪操作(crop)
# 在输入图像种的左上角,右上角,左下角,右上角,中间分别裁剪大小为(512*512)大小的子图像
def crop(img, gt):
img = read_data(img, 1)
gt = read_data(gt, 0)
assert img.shape[:2] == gt.shape[:2]
# 左上角
upL_subim, upL_subgt = img[:512, :512, :], gt[:512, :512]
# 右上角
upR_subim, upR_subgt = img[:512, -512:, :], gt[:512, -512:]
# 左下角
bottomL_subim, bottomL_subgt = img[-512:, :512, :], gt[-512:, :512]
# 右下角
bottomR_subim, bottomR_subgt = img[-512:, -512:, :], gt[-512:, -512:]
# 中间
(h, w) = img.shape[:2]
h_ctr, w_ctr = int(h/2), int(w/2)
center_subim, center_subgt = img[(h_ctr - 256):(h_ctr + 256), (w_ctr - 256):(w_ctr + 256), :],\
gt[(h_ctr - 256):(h_ctr + 256), (w_ctr - 256):(w_ctr + 256)],
(w - 256):(w + 256)]
return (upL_subim, upL_subgt), (upR_subim, upR_subgt), (bottomL_subim, bottomL_subgt), (
bottomR_subim, bottomR_subgt), (center_subim, center_subgt)