KERAS-YOLOV3的数据增强

前言

上篇KERAS-YOLOV3的代码走读
https://blog.csdn.net/yangchengtest/article/details/80664415
有数据增强的内容没有看明白。
这篇来介绍一下。

简介

数据增强的方法主要有:
1. 翻转变换 flip
2. 随机修剪 random crop
3. 色彩抖动 color jittering
4. 平移变换 shift
5. 尺度变换 scale
6. 对比度变换 contrast
7. 噪声扰动 noise
8. 旋转变换/反射变换 Rotation/reflection

KERAS-YOLOV3源码分析

def get_random_data(annotation_line, input_shape, random=True, max_boxes=20, jitter=.3, hue=.1, sat=1.5, val=1.5, proc_img=True)
def rand(a=0, b=1):
return np.random.rand()*(b-a) + a
rand返回两个入参之间的随机数。

缩放图片

# resize image
# 随机生成宽高比
new_ar = w/h * rand(1-jitter,1+jitter)/rand(1-jitter,1+jitter)
# 随机生成缩放比例。
scale = rand(.25, 2)
# 生成新的高宽数据,可能放大2倍。
if new_ar < 1:
    nh = int(scale*h)
    nw = int(nh*new_ar)
else:
    nw = int(scale*w)
    nh = int(nw/new_ar)
image = image.resize((nw,nh), Image.BICUBIC)

平移变换

# place image
# 随机水平位移
dx = int(rand(0, w-nw))
dy = int(rand(0, h-nh))
new_image = Image.new('RGB', (w,h), (128,128,128))
new_image.paste(image, (dx, dy))
image = new_image

翻转

# flip image or not
flip = rand()<.5
if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

颜色抖动

RGB->HSV->RGB

# distort image
# HSV抖动
hue = rand(-hue, hue)
sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat)
val = rand(1, val) if rand()<.5 else 1/rand(1, val)
# 归一化处理
# 内部函数,通过公式转化。具体函数不介绍。
x = rgb_to_hsv(np.array(image)/255.)
x[..., 0] += hue
x[..., 0][x[..., 0]>1] -= 1
x[..., 0][x[..., 0]<0] += 1
x[..., 1] *= sat
x[..., 2] *= val
# 避免S/V CHANNEL越界
x[x>1] = 1
x[x<0] = 0
image_data = hsv_to_rgb(x) # numpy array, 0 to 1

定义新的BOX位置

YOLO是位置检测的算法,在经过缩放和水平变换后,BOX的左边也需要相应的变化。

# correct boxes
box_data = np.zeros((max_boxes,5))
if len(box)>0:
    np.random.shuffle(box)
    box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
    box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
    ### 左右翻转
    if flip: box[:, [0,2]] = w - box[:, [2,0]]
    ### 定义边界
    box[:, 0:2][box[:, 0:2]<0] = 0
    box[:, 2][box[:, 2]>w] = w
    box[:, 3][box[:, 3]>h] = h
    ### 计算新的长宽
    box_w = box[:, 2] - box[:, 0]
    box_h = box[:, 3] - box[:, 1]
    box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
    if len(box)>max_boxes: box = box[:max_boxes]
    box_data[:len(box)] = box

你可能感兴趣的:(位置定位算法)