from __future__
import absolute_import
from __future__
import division
from __future__
import print_function
from datetime
import datetime
import os
import random
import sys
import threading
import numpy
as np
import tensorflow
as tf
class TFRecordsGenerator(object):
    """Write sharded TFRecord files for an image-classification data set.

    Expected inputs:

    * ``images_dir`` holds one sub-folder per class; every file inside
      ``images_dir/<class name>/`` is treated as an image of that class.
    * ``classes_file_path`` is a text file listing the class (folder) names,
      one per line.  Integer labels are assigned in file order starting at 1.

    Call :meth:`generate` to scan the images and write ``num_shards`` record
    files named ``<name>-<shard>-of-<num_shards>.tfrecord`` into
    ``tf_records_save_dir``.
    """

    def __init__(self,
                 name,
                 images_dir,
                 classes_file_path,
                 tf_records_save_dir,
                 num_shards=4,
                 num_threads=4):
        """Store the configuration; no I/O happens here.

        :param name: prefix used for the output shard file names.
        :param images_dir: root folder containing one sub-folder per class.
        :param classes_file_path: text file with one class name per line.
        :param tf_records_save_dir: folder the shard files are written into.
        :param num_shards: total number of output files; must be a multiple
            of ``num_threads`` (checked in :meth:`generate`).
        :param num_threads: number of worker threads.
        """
        self.name = name
        self.classes_file_path = classes_file_path
        self.images_dir = images_dir
        self.tf_records_saved_dir = tf_records_save_dir
        self.num_shards = num_shards
        self.num_threads = num_threads

    @staticmethod
    def _int64_feature(value):
        """Wrap an int (or list of ints) into an int64 ``tf.train.Feature``."""
        if not isinstance(value, list):
            value = [value]
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

    @staticmethod
    def _bytes_feature(value):
        """Wrap a single bytes value into a bytes ``tf.train.Feature``."""
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _convert_to_example(self, filename, image_buffer, label, text, height, width):
        """Build the ``tf.train.Example`` for one image.

        :param filename: path of the image; only its base name is stored.
        :param image_buffer: encoded JPEG bytes of the image.
        :param label: integer class label.
        :param text: human-readable class name.
        :param height: image height in pixels.
        :param width: image width in pixels.
        :return: the populated ``tf.train.Example``.
        """
        # Every image is stored as a 3-channel JPEG (PNGs are converted
        # before they get here), so these are constants.
        color_space = 'RGB'
        channels = 3
        image_format = 'JPEG'
        example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': self._int64_feature(height),
            'image/width': self._int64_feature(width),
            'image/color_space': self._bytes_feature(tf.compat.as_bytes(color_space)),
            'image/channels': self._int64_feature(channels),
            'image/class/label': self._int64_feature(label),
            'image/class/text': self._bytes_feature(tf.compat.as_bytes(text)),
            'image/format': self._bytes_feature(tf.compat.as_bytes(image_format)),
            'image/filename': self._bytes_feature(
                tf.compat.as_bytes(os.path.basename(filename))),
            'image/encoded': self._bytes_feature(tf.compat.as_bytes(image_buffer))}))
        return example

    class ImageCoder(object):
        """Helper that runs TensorFlow image decode/encode ops in one session."""

        def __init__(self):
            # One session is created and reused for every conversion.
            self._sess = tf.Session()

            # Graph fragment: PNG bytes -> RGB tensor -> JPEG bytes.
            self._png_data = tf.placeholder(dtype=tf.string)
            image = tf.image.decode_png(self._png_data, channels=3)
            self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

            # Graph fragment: JPEG bytes -> RGB tensor.
            self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
            self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

        def png_to_jpeg(self, image_data):
            """Re-encode PNG bytes as JPEG bytes."""
            return self._sess.run(self._png_to_jpeg,
                                  feed_dict={self._png_data: image_data})

        def decode_jpeg(self, image_data):
            """Decode JPEG bytes into an array and check it is H x W x 3."""
            image = self._sess.run(self._decode_jpeg,
                                   feed_dict={self._decode_jpeg_data: image_data})
            assert len(image.shape) == 3
            assert image.shape[2] == 3
            return image

    @staticmethod
    def _is_png(filename):
        """Return True when ``filename`` has a ``.png`` extension.

        The check is done on the extension only (case-insensitively), so a
        ``.png`` fragment elsewhere in the path does not count.
        """
        return os.path.splitext(filename)[1].lower() == '.png'

    @staticmethod
    def _split_ranges(num_items, num_parts):
        """Split ``range(num_items)`` into ``num_parts`` contiguous ranges.

        :return: list of ``[start, end)`` pairs (plain Python ints) that
            together cover ``0 .. num_items``.
        """
        bounds = np.linspace(0, num_items, num_parts + 1).astype(int).tolist()
        return [[bounds[i], bounds[i + 1]] for i in range(len(bounds) - 1)]

    def _process_image(self, filename, coder):
        """Read one image file and return its JPEG bytes and dimensions.

        :param filename: path of the image file.
        :param coder: an :class:`ImageCoder` used for conversion/decoding.
        :return: ``(image_data, height, width)`` where ``image_data`` is the
            JPEG-encoded bytes (PNG inputs are converted first).
        """
        # Images are binary data: read as bytes, not text.
        with tf.gfile.FastGFile(filename, 'rb') as f:
            image_data = f.read()

        if self._is_png(filename):
            print('Converting PNG to JPEG for %s' % filename)
            image_data = coder.png_to_jpeg(image_data)

        # Decode only to validate the image and learn its size.
        image = coder.decode_jpeg(image_data)
        assert len(image.shape) == 3
        height = image.shape[0]
        width = image.shape[1]
        assert image.shape[2] == 3
        return image_data, height, width

    def _process_image_files_batch(self, coder, thread_index, ranges, name, file_names,
                                   texts, labels, num_shards):
        """Worker body: write the shards that belong to one thread.

        :param coder: :class:`ImageCoder` instance (the same object is handed
            to every worker by :meth:`_process_image_files`).
        :param thread_index: index of this worker within ``ranges``.
        :param ranges: list of ``[start, end)`` index pairs, one per worker.
        :param name: prefix for the output file names.
        :param file_names: all image paths.
        :param texts: class name for each image.
        :param labels: integer label for each image.
        :param num_shards: total number of shards across all workers.
        """
        num_threads = len(ranges)
        assert not num_shards % num_threads
        num_shards_per_batch = num_shards // num_threads

        start, end = ranges[thread_index][0], ranges[thread_index][1]
        # Boundaries of each shard inside this worker's slice of the data.
        shard_ranges = np.linspace(start, end, num_shards_per_batch + 1).astype(int)
        num_files_in_thread = end - start

        counter = 0
        for s in range(num_shards_per_batch):
            # Global shard number, unique across workers.
            shard = thread_index * num_shards_per_batch + s
            output_filename = '%s-%.5d-of-%.5d.tfrecord' % (name, shard, num_shards)
            output_file = os.path.join(self.tf_records_saved_dir, output_filename)
            writer = tf.python_io.TFRecordWriter(output_file)

            shard_counter = 0
            try:
                for i in range(int(shard_ranges[s]), int(shard_ranges[s + 1])):
                    filename = file_names[i]
                    label = labels[i]
                    text = texts[i]

                    image_buffer, height, width = self._process_image(filename, coder)
                    example = self._convert_to_example(filename, image_buffer, label,
                                                       text, height, width)
                    writer.write(example.SerializeToString())
                    shard_counter += 1
                    counter += 1

                    if not counter % 1000:
                        print('%s [thread %d]: Processed %d of %d images in thread batch.' %
                              (datetime.now(), thread_index, counter, num_files_in_thread))
                        sys.stdout.flush()
            finally:
                # Close the shard even if an image fails to process.
                writer.close()

            print('%s [thread %d]: Wrote %d images to %s' %
                  (datetime.now(), thread_index, shard_counter, output_file))
            sys.stdout.flush()

        print('%s [thread %d]: Wrote %d images to %d shards.' %
              (datetime.now(), thread_index, counter, num_shards_per_batch))
        sys.stdout.flush()

    def _process_image_files(self, file_names, texts, labels):
        """Split the data across worker threads and wait for them to finish.

        :param file_names: all image paths.
        :param texts: class name for each image.
        :param labels: integer label for each image.
        """
        assert len(file_names) == len(texts)
        assert len(file_names) == len(labels)

        # One contiguous [start, end) slice of the file list per thread.
        ranges = self._split_ranges(len(file_names), self.num_threads)

        print('Launching %d threads for spacings: %s' % (self.num_threads, ranges))
        sys.stdout.flush()

        coord = tf.train.Coordinator()
        coder = self.ImageCoder()

        threads = []
        for thread_index in range(len(ranges)):
            args = (coder, thread_index, ranges, self.name, file_names,
                    texts, labels, self.num_shards)
            t = threading.Thread(target=self._process_image_files_batch, args=args)
            t.start()
            threads.append(t)

        coord.join(threads)
        print('%s: Finished writing all %d images in data set.' %
              (datetime.now(), len(file_names)))
        sys.stdout.flush()

    def _find_image_files(self):
        """List every image with its class name and label, in shuffled order.

        :return: ``(file_names, texts, labels)`` -- three parallel lists.
            Labels start at 1 and follow the order of the classes file.
        """
        print('Determining list of input files and labels from %s.' % self.images_dir)
        with tf.gfile.FastGFile(self.classes_file_path, 'r') as classes_file:
            stripped = [line.strip() for line in classes_file.readlines()]
        # A blank line would turn into the glob "<images_dir>//*" and pick up
        # the class folders themselves, so blank lines are ignored.
        unique_labels = [text for text in stripped if text]

        labels = []
        file_names = []
        texts = []

        for label_index, text in enumerate(unique_labels, 1):
            jpeg_file_path = '%s/%s/*' % (self.images_dir, text)
            matching_files = tf.gfile.Glob(jpeg_file_path)

            labels.extend([label_index] * len(matching_files))
            texts.extend([text] * len(matching_files))
            file_names.extend(matching_files)

            if not label_index % 100:
                print('Finished finding files in %d of %d classes.' % (
                    label_index, len(unique_labels)))

        # Shuffle with a fixed seed so the shard contents are reproducible.
        shuffled_index = list(range(len(file_names)))
        random.seed(12345)
        random.shuffle(shuffled_index)

        file_names = [file_names[i] for i in shuffled_index]
        texts = [texts[i] for i in shuffled_index]
        labels = [labels[i] for i in shuffled_index]

        print('Found %d JPEG files across %d labels inside %s.' %
              (len(file_names), len(unique_labels), self.images_dir))
        print('[INFO] Attempting logging out file_names list: {}'.format(
            '\n'.join(file_names)))
        return file_names, texts, labels

    def generate(self):
        """Scan the image folders and write all TFRecord shards.

        Raises ``AssertionError`` when ``num_shards`` is not a multiple of
        ``num_threads``.
        """
        assert not self.num_shards % self.num_threads, (
            'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
        print('Saving results to %s' % self.tf_records_saved_dir)

        file_names, texts, labels = self._find_image_files()
        self._process_image_files(file_names, texts, labels)
        print('All Done! Solved {} images. tf_records file saved into {}.'.format(
            len(file_names), os.path.abspath(self.tf_records_saved_dir)))