Take classification data as an example.
A TFRecords file stores the raw image data and the label data in binary format. The contents are stored in the following form:
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    "height": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0]])),
    "width": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[1]])),
    "channel": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[2]])),
    "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
The writing code is as follows:
"""
2018.7.20 --
Use this code to create the dataset for classification.
first write your own class_dic according to the folder.
then set the root path for all class folders in "data_path"
:parameter classes: the name of folder for different classes' images
:parameter writer: the filename of tfrecord
"""
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import scipy.misc as misc
import numpy as np
cwd = os.getcwd()
class_dic = {
"airplane": 0,
"ship": 1,
"background": 2
}
def tf_record_create(cwd, classes, writer):
"""
:param cwd: filepath which contains all classes(one class, one folder.
:param classes: {'class_name1', 'class_name2', ...}
:param writer: tf.python_io.TFRecordWriter("*****.tfrecords")):
:return:
"""
for index, name in enumerate(classes):
label = class_dic[name]
class_path = cwd+'/'+name+'/'
for img_name in os.listdir(class_path):
img_path = class_path+img_name #每一个图片的地址
img_pil = Image.open(img_path)
img = np.array(img_pil)
# img = img.resize((IMG_HEIGHT, IMG_WIDTH))
# instead of resize, get the image shape and write in the example
shape_debug = img.shape
shape = list(img.shape)
if len(shape) == 2:
shape.append(1)
shape = np.array(shape, np.int64)
img_raw = img.tobytes() # 将图片转化为二进制格式
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
"height": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0]])),
"width": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[1]])),
"channel": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[2]])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) # example对象对label和image数据进行封装
writer.write(example.SerializeToString()) # 序列化为字符串
writer.close()
# print("image to tfrecord processed")
return
if __name__ == '__main__':
# training data creation
data_path = "../DATA/train"
tf_record_create(cwd=data_path,
classes=['airplane', 'ship', 'background'],
writer=tf.python_io.TFRecordWriter("xingtu_cls_4360train.tfrecords"))
print("creating processed")
When reading a TFRecords file, the contents are parsed back according to the same feature dict that was used at write time. Note that 'img_raw' comes back as a tf.string scalar, so it must afterwards be decoded with tf.decode_raw and reshaped using the stored height/width/channel:
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'height': tf.FixedLenFeature([], tf.int64),
                                       'width': tf.FixedLenFeature([], tf.int64),
                                       'channel': tf.FixedLenFeature([], tf.int64),
                                       'img_raw': tf.FixedLenFeature([], tf.string),
                                   })
The full reading code is as follows:
"""
2018.7.20 --
Use this code to create the dataset for classification.
first write your own class_dic according to the folder.
then set the root path for all class folders in "data_path"
:parameter classes: the name of folder for different classes' images
:parameter writer: the filename of tfrecord
"""
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import scipy.misc as misc
import numpy as np
cwd = os.getcwd()
class_dic = {
"airplane": 0,
"ship": 1,
"background": 2
}
def tf_record_read_and_save(filepath):
filename_queue = tf.train.string_input_producer([filepath]) # 读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channel': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
}) # 将image数据和label取出来
h = tf.cast(features['height'], tf.int32)
w = tf.cast(features['width'], tf.int32)
c = tf.cast(features['channel'], tf.int32)
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [h, w, c])
label = tf.cast(features['label'], tf.int32)
label = tf.reshape(label, [1])
with tf.Session() as sess: # 开始一个会话
init_op = tf.global_variables_initializer()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1):
example, l = sess.run([image, label]) # 在会话中取出image和label
h, w, c = sess.run([h, w, c])
print("(h, w, c) = {}, {}, {}".format(h, w, c))
print("image's shape:", example.shape)
img = np.array(np.squeeze(example))
img = Image.fromarray(img, 'RGB' if example.shape[2] == 3 else 'L')
img.save(str(i) + '_Label_' + str(l) + '.jpg') # 存下图片
# print(example, l)
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
tf_record_read_and_save(filepath="xingtu_cls_4360train.tfrecords",)
Note: the code for tfrecord creation and read-back checking is cls_tfrecord_create.py; the reading and display code is /data_made/tfrecord_read_and_show/demo_tfrecord_read_c1.py.
There are two ways to read the tfrecord dataset: 1. next_batch; 2. dataset (iterator.get_next()).
next_batch:
Use tf.train.string_input_producer(), read_single_example_and_decode(), preprocess_image() and tf.train.batch() to create the (image, label) tensors.
dataset:
Use tf.data.TFRecordDataset(), dataset = dataset.map(_parser) and the other dataset methods to create the (image, label) tensors.
The code is as follows:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

_HEIGHT = 256
_WIDTH = 256
_CHANNELS = 1

# preprocessing parameters
random_extend_ratio = 1.2
random_contrast_lower = 0.3
random_contrast_upper = 1.0
random_brightness_max_delta = 0.5
def preprocess_image(image, is_training):
    if is_training:
        # resize larger, then randomly crop back to the target size
        image = tf.image.resize_images(images=image,
                                       size=[tf.cast(_HEIGHT * random_extend_ratio, tf.int32),
                                             tf.cast(_WIDTH * random_extend_ratio, tf.int32)])
        image = tf.random_crop(image, [_HEIGHT, _WIDTH, _CHANNELS])
        # flip
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        # adjust contrast and brightness (hue and saturation are not supported
        # for one-channel gray images)
        image = tf.image.random_contrast(image, lower=random_contrast_lower, upper=random_contrast_upper)
        image = tf.image.random_brightness(image, max_delta=random_brightness_max_delta)
    else:
        # according to the test, resize_images and resize_area resize the same way and
        # did not show the problem mentioned by the blog(), while resize_bicubic() has
        # a side effect when align_corners is set to True.
        image = tf.image.resize_images(images=image,
                                       size=[_HEIGHT, _WIDTH])
    image = tf.image.per_image_standardization(image)
    return image
def read_single_example_and_decode(filename_queue):
    # reader = tf.TFRecordReader(options=tfrecord_options)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channel': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    img_height = tf.cast(features['height'], tf.int32)
    img_width = tf.cast(features['width'], tf.int32)
    img_channel = tf.cast(features['channel'], tf.int32)
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, shape=[img_height, img_width, img_channel])
    label = tf.cast(features['label'], tf.int32)
    return img, label
def next_batch(dataset_name, batch_size, is_training):
    if dataset_name == "xingtu":
        pattern = "../tfrecords/xingtu_cls_63test.tfrecords"
    else:
        raise ValueError("xingtu only")
    print('tfrecord path is -->', os.path.abspath(pattern))
    # filename_tensorlist = tf.train.match_filenames_once(pattern)
    filename_queue = tf.train.string_input_producer([pattern])
    image, label = read_single_example_and_decode(filename_queue)
    image = preprocess_image(image, is_training)
    img_batch, label_batch = tf.train.batch([image, label],
                                            batch_size=batch_size,
                                            capacity=32,
                                            num_threads=4,
                                            dynamic_pad=True)
    return img_batch, label_batch
    # return image, label
def input_fn(filename, is_training, batch_size, shuffle_buffer, num_epochs=1):
    if not os.path.exists(filename):
        raise ValueError("no such file exists")
    def _parser(example_proto):
        features = {
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channel': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }
        parsed_features = tf.parse_single_example(example_proto, features=features)
        height = tf.cast(parsed_features['height'], tf.int32)
        width = tf.cast(parsed_features['width'], tf.int32)
        c = tf.cast(parsed_features['channel'], tf.int32)
        image = tf.decode_raw(parsed_features['img_raw'], tf.uint8)
        image = tf.reshape(image, [height, width, c])
        image = preprocess_image(image, is_training)
        label = tf.cast(parsed_features['label'], tf.int32)
        return image, label

    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.prefetch(buffer_size=batch_size)
    if is_training:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(_parser)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    next_image, next_label = iterator.get_next()
    return next_image, next_label
if __name__ == '__main__':
    image, label = next_batch(dataset_name="xingtu", batch_size=1, is_training=True)
    # image, label = input_fn("./data/one_image.tfrecords",
    #                         is_training=False, batch_size=1, shuffle_buffer=1, num_epochs=1)
    tf.summary.image("image", image)
    summary_op = tf.summary.merge_all()
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    global_step = tf.train.get_or_create_global_step()
    with tf.Session() as sess:
        writer = tf.summary.FileWriter("./sar_summary", sess.graph)
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        img_ = sess.run(image)
        # the image has been standardized to float; let matplotlib rescale it for
        # display instead of casting to uint8, which would wrap the negative values
        img_show = np.squeeze(img_)
        plt.figure()
        plt.imshow(img_show)
        plt.show()
        summary = sess.run(summary_op)
        writer.add_summary(summary, 0)
        coord.request_stop()
        coord.join(threads)
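
For completeness, a minimal sketch of driving the tf.data pipeline from input_fn() above directly (the tfrecord path and batch parameters below are placeholders, assuming the training file written earlier). Unlike next_batch(), the one-shot iterator needs no queue runners; it simply raises OutOfRangeError once the epochs are exhausted:

images, labels = input_fn("xingtu_cls_4360train.tfrecords",
                          is_training=True, batch_size=8,
                          shuffle_buffer=100, num_epochs=1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        while True:
            img_batch, label_batch = sess.run([images, labels])
            print(img_batch.shape, label_batch)
    except tf.errors.OutOfRangeError:
        pass  # the one-shot iterator is exhausted after num_epochs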