图像数据增强

一、使用 TensorFlow API 对图像做随机亮度、对比度、饱和度修改，并叠加随机高斯噪声
核心部分是 aug_op 函数，这可是菜鸟的心血啊！

#coding:utf-8
import tensorflow as tf
import cv2
import random
import sys
import os
import shutil

#os.environ["CUDA_VISIBLE_DEVICES"] = ""

def random_normal(img, mean=0.0, stddev=2.0, limit=6):
    """Add bounded Gaussian pixel noise to an image tensor.

    Noise is sampled from N(mean, stddev), truncated toward zero to int32,
    and every sample whose magnitude exceeds ``limit`` is zeroed out (it is
    discarded, not clipped). The noisy image is clamped to [0, 255] before
    being cast back to uint8 so out-of-range pixels saturate.

    Args:
        img: image tensor; presumably uint8 in [0, 255] (the caller in this
            script feeds a uint8 placeholder) -- confirm for other callers.
        mean: mean of the Gaussian noise.
        stddev: standard deviation of the Gaussian noise.
        limit: noise samples with |noise| > limit are replaced by 0.

    Returns:
        uint8 tensor with the same shape as ``img``.
    """
    # Work in int32 so negative noise and sums above 255 are representable.
    img = tf.cast(img, dtype=tf.int32)
    noise = tf.random_normal(tf.shape(img), mean, stddev)
    # float -> int32 truncates toward zero; a signed type keeps negative noise.
    noise = tf.cast(noise, dtype=tf.int32)
    # Keep only samples inside [-limit, limit]; everything else becomes 0.
    in_range = tf.logical_and(tf.less_equal(noise, limit),
                              tf.greater_equal(noise, -limit))
    noise = tf.where(in_range, noise, tf.zeros_like(noise))
    new_img = tf.add(img, noise)
    # Saturate before narrowing: a plain int32 -> uint8 cast could wrap around.
    new_img = tf.clip_by_value(new_img, 0, 255)
    # Bug fix: the original cast `img` here, silently discarding the noise.
    return tf.cast(new_img, dtype=tf.uint8)

def aug_op(img):
    """Build graph ops that randomly apply up to four photometric augmentations.

    Brightness, contrast, saturation and Gaussian noise (``random_normal``)
    are each switched on or off by their own independent fair coin flip. The
    coins are re-drawn every time the returned tensors are evaluated.

    Args:
        img: image tensor (this script feeds a uint8 [H, W, 3] placeholder).

    Returns:
        (new_img, cnt): the augmented image tensor and an int32 tensor of
        shape [1] holding how many augmentations were applied (0..4). A count
        of 0 means the image came back unmodified.
    """

    def _coin():
        # Independent 0/1 draw per augmentation. The original used a single
        # shuffled coin for all four ops, making them all-or-nothing.
        return tf.random_uniform([], minval=0, maxval=2, dtype=tf.int32) > 0

    def _maybe(apply_fn, image, count):
        # Apply `apply_fn` with probability 1/2 and bump the count only when
        # it was applied. Binding image/count as parameters avoids any
        # late-binding surprises in the branch lambdas.
        return tf.cond(_coin(),
                       lambda: (apply_fn(image), count + 1),
                       lambda: (image, count))

    new_img = img
    cnt = tf.constant([0], dtype=tf.int32)

    new_img, cnt = _maybe(lambda x: tf.image.random_brightness(x, 0.3), new_img, cnt)
    new_img, cnt = _maybe(lambda x: tf.image.random_contrast(x, 0.5, 2), new_img, cnt)
    new_img, cnt = _maybe(lambda x: tf.image.random_saturation(x, 0.3, 2), new_img, cnt)
    # The original's false branch here returned cnt+1, so cnt could never be 0.
    new_img, cnt = _maybe(random_normal, new_img, cnt)

    return new_img, cnt

if __name__ == "__main__":
    # Augment every image in src_dir `loop_times` times and copy its landmark
    # (.pts) file alongside each augmented copy that was actually modified.
    # Usage: python img_aug.py SRC_DIR SRC_LANDMARK_DIR DST_DIR DST_LANDMARK_DIR
    if len(sys.argv) < 5:
        print("usage: {} src_dir src_landmark_dir dst_dir dst_landmark_dir".format(sys.argv[0]))
        sys.exit(1)

    src_dir, src_landmark_dir, dst_dir, dst_landmark_dir = sys.argv[1:5]
    loop_times = 5  # augmented copies attempted per source image

    # The message says "exit", so actually stop instead of crashing later.
    for required_dir in (src_dir, src_landmark_dir):
        if not os.path.exists(required_dir):
            print("{} not exists,exit.".format(required_dir))
            sys.exit(1)
    for out_dir in (dst_dir, dst_landmark_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    input_image = tf.placeholder(tf.uint8, shape=[None, None, 3], name="input_image")
    op, cnt_op = aug_op(input_image)

    src_names = os.listdir(src_dir)
    total_cnt = len(src_names)
    cnt = 0
    new_cnt = 0
    # The context manager guarantees the session is closed when done.
    with tf.Session() as sess:
        for name in src_names:
            src_path = os.path.join(src_dir, name)
            img = cv2.imread(src_path)
            if img is None:
                print("Image:{} read none, skip".format(src_path))
                continue

            pre, ext = os.path.splitext(name)
            src_pts = os.path.join(src_landmark_dir, pre + ".pts")
            # Check landmarks before writing anything, so a missing .pts file
            # cannot leave orphan augmented images behind (or crash the run).
            if not os.path.exists(src_pts):
                print("Landmark:{} not exists, skip".format(src_pts))
                continue

            for i in range(loop_times):
                new_img, isaug = sess.run([op, cnt_op], feed_dict={input_image: img})
                # isaug counts the applied augmentations; 0 means an unmodified copy.
                if isaug == 0:
                    continue

                suffix = "_" + str(i)
                cv2.imwrite(os.path.join(dst_dir, pre + suffix + ext), new_img)
                shutil.copy(src_pts, os.path.join(dst_landmark_dir, pre + suffix + ".pts"))
                new_cnt = new_cnt + 1
            cnt = cnt + 1
            print("progress: {}/{}, new image:{}".format(cnt, total_cnt, new_cnt))

    #python img_aug.py /home/hongjie.li/FaceAlign/dataset-landmark/20181018-1000-camera /home/hongjie.li/FaceAlign/dataset-landmark/landmark_all /home/hongjie.li/FaceAlign/dataset-landmark/20181018-1000-camera-aug /home/hongjie.li/FaceAlign/dataset-landmark/20181018-1000-camera-aug-landmark
 

你可能感兴趣的:(图像数据增强)