[TensorFlow实战] 图片预处理

代码

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Target size in pixels of the preprocessed images produced by main().
WIDTH, HEIGHT = 456, 730

def distort_color(image, color_ordering=0):
    """Randomly jitter the colors of an image.

    Args:
        image:
            image tensor to adjust; presumably float32 with values in
            [0, 1] (the result is clipped to that range). The saturation
            and hue ops need the last dimension to be 3 (RGB).
        color_ordering:
            chooses the order of the random adjustments:
            0 -> brightness, saturation, hue, contrast;
            1 -> saturation, brightness, contrast, hue;
            any other value -> no adjustment (slot for new orderings).
    Returns:
        the adjusted image tensor, clipped to [0.0, 1.0].
    """
    def brightness(img):
        return tf.image.random_brightness(img, max_delta=32. / 255.)

    def saturation(img):
        return tf.image.random_saturation(img, lower=0.5, upper=1.5)

    def hue(img):
        return tf.image.random_hue(img, max_delta=0.2)

    def contrast(img):
        return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    orderings = {
        0: (brightness, saturation, hue, contrast),
        1: (saturation, brightness, contrast, hue),
    }
    # Unknown orderings apply nothing; define new settings in the table above.
    for adjust in orderings.get(color_ordering, ()):
        image = adjust(image)
    return tf.clip_by_value(image, 0.0, 1.0)

def preprocess_for_train(image, height, width, bbox=None):
    """Randomly augment one image before sending it to the neural network.

    A random crop is sampled (constrained by ``bbox``), resized to
    ``(height, width)`` with a randomly chosen resize method, randomly
    flipped left/right, and finally color-jittered by ``distort_color``.

    Args:
        image:
            original image tensor; converted to float32 if it is not
            float32 already.
        height, width:
            size of the OUTPUT image after resizing (not the size of the
            original image).
        bbox:
            bounding box(es) the random crop is sampled against, as a
            float32 tensor of shape [1, N, 4] in relative coordinates;
            defaults to the whole image ([0, 0, 1, 1]).
    Returns:
        processed float32 image tensor resized to (height, width).
    """
    # If no bbox is given, the crop may be sampled from the whole image.
    if bbox is None:
        bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32,
                           shape=[1, 1, 4])
    # The color ops below work on float images; convert_image_dtype also
    # rescales integer pixel values.
    if image.dtype != tf.float32:
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # Sample a random crop window relative to bbox and cut it out.
    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image), bounding_boxes=bbox)
    distorted_image = tf.slice(image, bbox_begin, bbox_size)

    # The resize method (0-3) and the color ordering (0-1) are drawn in
    # Python while the graph is built, so each call to this function bakes
    # in a different choice. int() hands TF a plain Python int rather than
    # a numpy scalar.
    resize_method = int(np.random.randint(4))
    distorted_image = tf.image.resize_images(
        distorted_image, (height, width), method=resize_method)

    # Random left/right flip.
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    # Random color jitter.
    color_ordering = int(np.random.randint(2))
    distorted_image = distort_color(distorted_image, color_ordering)

    return distorted_image

def main():
    """Augment cat.jpg ten times, saving each result as 0.jpg ... 9.jpg.

    Each augmented image is also displayed with matplotlib.
    """
    # JPEG data is binary: read it in 'rb' mode (text mode tries to decode
    # the bytes as text on Python 3), and let the context manager close it.
    with tf.gfile.FastGFile(r'cat.jpg', 'rb') as f:
        image_raw_data = f.read()
    # The session becomes the default one, so Tensor.eval() below uses it.
    with tf.Session():
        # channels=3 forces RGB output; the hue/saturation ops used by the
        # color jitter require exactly 3 channels.
        img_data = tf.image.decode_jpeg(image_raw_data, channels=3)
        for i in range(10):
            # preprocess_for_train takes (image, height, width).
            res = preprocess_for_train(img_data, HEIGHT, WIDTH)
            plt.imshow(res.eval())
            plt.savefig(str(i) + ".jpg")
            plt.show()

if __name__ == '__main__':
    # Run the demo only when this file is executed as a script.
    main()

效果图（运行上述代码对 cat.jpg 进行随机预处理后得到的部分结果）

[TensorFlow实战] 图片预处理_第1张图片
[TensorFlow实战] 图片预处理_第2张图片
[TensorFlow实战] 图片预处理_第3张图片
[TensorFlow实战] 图片预处理_第4张图片

你可能感兴趣的:(机器学习)