Tensorflow(3)-image

本文简要介绍一下tensorflow对image的预处理.->官方API

场景:对于训练集,采用SGD训练。每一轮迭代前,对这批图像预处理(随机增强),然后洗牌,开始训练。这个操作,图像预处理可以并行的放在cpu上计算,因此对模型不会带来太多的额外开销。

1)tf.random_crop()
parameters:
image: 3-D [height, width, channels]
size: [height, width]
seed
name

随机一个offset,然后按照size进行裁剪。offset满足均值分布。

2)tf.image.random_flip_left_right()
parameters:
image
seed

1/2的概率水平翻转图片。

3)tf.image.random_brightness()
parameters:
image
max_delta: float, 非负
seed

随机调整图片亮度。随机取一个值[-max_delta, max_delta] 然后加到整个图像上。

4)tf.image.random_contrast()
parameters:
image
lower: float
upper: float
seed

随机调整图片对比度。contrast_factor 随机取自 [lower, upper]
(x-mean) * contrast_factor + mean

Tensorflow实际使用中需要通过session

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from scipy.misc import imread

IMAGE_SIZE = 200

img = imread('/cat.jpg')

reshaped_image = tf.cast(img, tf.float32)
distorted_image = tf.random_crop(reshaped_image, [IMAGE_SIZE,IMAGE_SIZE,3])
distorted_image = tf.image.random_flip_left_right(distorted_image)
distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)

rtval = tf.cast(distorted_image, tf.uint8)

with tf.Session() as sess:

    plt.imshow(sess.run(rtval))
    plt.show()

Tensorflow(3)-image_第1张图片

你可能感兴趣的:(机器学习,tensorflow)