先看一个来自 Stack Overflow 的例子:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random
# load mnist
from tensorflow.examples.tutorials.mnist import input_data
# Demo: rotate every MNIST training image by its own random angle and plot
# a few originals next to their rotated versions.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Each row of `mnist.train.images` is one flattened image; the side length
# is recovered as the square root of the row length.
input_size = mnist.train.images.shape[1]
side_size = int(np.sqrt(input_size))

# Graph: flat rows -> (batch, side, side, 1) images -> per-image rotation.
dataset = tf.placeholder(tf.float32, [None, input_size])
images = tf.reshape(dataset, (-1, side_size, side_size, 1))
# One angle per image, uniform in [-pi/4, pi/4).
random_angles = tf.random.uniform(
    shape=(tf.shape(images)[0],), minval=-np.pi / 4, maxval=np.pi / 4)
rotated_images = tf.contrib.image.transform(
    images,
    tf.contrib.image.angles_to_projective_transforms(
        random_angles,
        tf.cast(tf.shape(images)[1], tf.float32),
        tf.cast(tf.shape(images)[2], tf.float32)))

# Run the graph once over the whole training set; the context manager
# releases the session's resources as soon as the result is fetched.
with tf.Session() as sess:
    result = sess.run(rotated_images, feed_dict={dataset: mnist.train.images})

# Scale to 0..255 and drop the channel axis so imshow gets 2-D uint8 images.
original = np.reshape(
    mnist.train.images * 255, (-1, side_size, side_size)).astype(np.uint8)
rotated = np.reshape(result * 255, (-1, side_size, side_size)).astype(np.uint8)

# Show 10 random samples: originals on the top row, rotated on the bottom.
num_samples = 10
fig, axes = plt.subplots(2, num_samples, figsize=(15, 4.5))
# Draw indices from the arrays that are actually displayed (the training
# set), not from the test set, so the index range always matches.
choice = np.random.choice(len(original), num_samples)
for col, sample_idx in enumerate(choice):
    axes[0][col].set_axis_off()
    axes[0][col].imshow(original[sample_idx], interpolation='nearest', cmap='gray')
    axes[1][col].set_axis_off()
    axes[1][col].imshow(rotated[sample_idx], interpolation='nearest', cmap='gray')
plt.show()
效果如下
再自己实现:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import numpy as np
import tensorflow as tf
import random
import matplotlib.pyplot as plt
def read_tfrecord_use_queue_runner(filename, batch_size=32, image_shape=(224, 224, 3),
                                   max_rotation=np.pi / 6, min_after_dequeue=2000,
                                   num_threads=8):
    """Build a queue-based TFRecord input pipeline with random rotation.

    Each record is parsed for the features 'image/label', 'image/height',
    'image/width' and 'image/encoded'. The encoded image is JPEG-decoded,
    reshaped to `image_shape`, batched with shuffling, and every image in the
    batch is rotated by its own random angle.

    Args:
        filename: Path of the TFRecord file to read.
        batch_size: Number of examples per batch.
        image_shape: (height, width, channels) that every decoded image is
            reshaped to; `image_shape[2]` is also used as the number of
            channels requested from the JPEG decoder.
        max_rotation: Rotation angles (radians) are drawn uniformly from
            [-max_rotation, max_rotation).
        min_after_dequeue: Minimum number of elements kept in the shuffle
            queue after a dequeue.
        num_threads: Number of threads enqueuing examples.

    Returns:
        A tuple `(rotated_images, label_batch)` of tensors. They only yield
        data once queue runners have been started in the session
        (`tf.train.start_queue_runners`).
    """
    filequeue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, example_tensor = reader.read(filequeue)
    # Height and width stay in the parse spec so the set of features a record
    # must provide is unchanged, even though their values are not used below.
    example_features = tf.parse_single_example(
        example_tensor,
        features={
            'image/label': tf.FixedLenFeature([], dtype=tf.string),
            'image/height': tf.FixedLenFeature([], dtype=tf.int64),
            'image/width': tf.FixedLenFeature([], dtype=tf.int64),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string),
        }
    )
    # The label is already parsed as tf.string, so no cast is needed.
    label = example_features['image/label']
    image = tf.image.decode_jpeg(example_features['image/encoded'], channels=image_shape[2])
    # Reshaping to a fixed shape gives the batching op a fully defined shape.
    image = tf.reshape(image, tf.stack(image_shape))

    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        min_after_dequeue=min_after_dequeue,
        capacity=capacity,
        num_threads=num_threads
    )

    # ---- Data augmentation ----
    # Random rotation: one angle per image in the batch.
    random_angles = tf.random.uniform(
        shape=(tf.shape(image_batch)[0],), minval=-max_rotation, maxval=max_rotation)
    rotated_images = tf.contrib.image.transform(
        image_batch,
        tf.contrib.image.angles_to_projective_transforms(
            random_angles,
            tf.cast(tf.shape(image_batch)[1], tf.float32),
            tf.cast(tf.shape(image_batch)[2], tf.float32)))
    # Optional extra augmentations, disabled by default:
    # Randomly adjust the contrast within [lower, upper].
    # rotated_images = tf.image.random_contrast(rotated_images, 0.1, 0.6)
    # Randomly adjust the saturation within [lower, upper].
    # rotated_images = tf.image.random_saturation(rotated_images, 0, 5)
    # Randomly adjust the brightness within [-max_delta, max_delta).
    # rotated_images = tf.image.random_brightness(rotated_images, max_delta=0.5)
    return rotated_images, label_batch
def main():
    """Read one shuffled batch from 'validation.tfrecord' and plot it.

    Builds the input pipeline, starts the queue runners, fetches
    `max_time_steps` batches, shows every image of each batch in a subplot
    grid and prints its label. Queue-runner threads and the session are
    always shut down, after which the interpreter exits.
    """
    import sys

    tfrecord_file = 'validation.tfrecord'
    batch_size = 10
    # Subplot grid sized from the batch so every image gets a cell.
    plot_cols = 2
    plot_rows = (batch_size + plot_cols - 1) // plot_cols
    batch_tensors = read_tfrecord_use_queue_runner(
        tfrecord_file, batch_size=batch_size, image_shape=(96, 200, 3))

    sess = tf.Session()
    init_op = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    max_time_steps = 1
    try:
        for _ in range(max_time_steps):
            if coord.should_stop():
                break
            images, labels = sess.run(batch_tensors)
            for i, (image, label) in enumerate(zip(images, labels)):
                # image = np.squeeze(image, axis=2)
                plt.subplot(plot_rows, plot_cols, i + 1)
                plt.imshow(image)
                print(label)
            plt.show()
    # Catch the exception class itself; writing `OutOfRangeError()` would try
    # to construct an instance here and fail instead of handling the error.
    except tf.errors.OutOfRangeError:
        print('Done training')
    finally:
        coord.request_stop()  # ask the queue-runner threads to stop
        coord.join(threads)  # wait for all of them to finish
        sess.close()
    # Kept outside `finally` so an unexpected exception is not replaced by
    # SystemExit and still surfaces with its traceback.
    sys.exit()


if __name__ == "__main__":
    main()
增强的过程中,可以使得随机的角度在某个范围以内,并服从正态分布
例如 batch 大小为 128 时,可以让旋转角度落在 (-30°, 30°) 之间,并服从以 0° 为中心的正态分布(即大部分角度集中在 0° 附近),如:
如何生成正态分布代码可以参照如下:
import tensorflow as tf
import matplotlib.pyplot as plt
# Demo: draw 128 samples from a normal distribution (mean 0, stddev 10) and
# plot their histogram. An op-level seed is passed to the sampler.
# NOTE(review): a plain normal distribution is unbounded; if the samples must
# stay inside a fixed range, tf.truncated_normal may be a better fit - confirm.
w1 = tf.Variable(tf.random_normal([128], mean=0, stddev=10, seed=1))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sw1 = sess.run(w1)
# `density=True` normalizes the histogram to a probability density; it
# replaces the `normed` keyword, which newer Matplotlib releases no longer accept.
plt.hist(sw1, bins=100, density=True)
plt.show()