TensorFlow TFRecord 数据的读取
1、读取过程中进行 batch + shuffle 操作
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
def normalize(image, label):
    """Cast the image to float32 and divide it by 255; pass the label through.

    If the pixels are in [0, 255] (e.g. decoded from uint8), the result lies
    in [0, 1].

    Args:
        image: image tensor of any numeric dtype.
        label: returned unchanged.

    Returns:
        Tuple ``(scaled_image, label)``.
    """
    as_float = tf.cast(image, tf.float32)
    scaled = as_float / 255
    return scaled, label
def read_and_decode(filename):
    """Build graph ops that read and decode one example from a TFRecord file.

    This uses the TF1 queue-based input pipeline. Queue runners must be
    started (``tf.train.start_queue_runners``) before the returned tensors
    are evaluated, and every evaluation consumes a new record.

    Each record must contain ``height`` and ``width`` (int64 scalars) plus
    ``image_raw`` and ``label_raw`` (raw uint8 byte strings).

    Args:
        filename: path to a single ``.tfrecords`` file.

    Returns:
        image: float32 tensor of shape [height, width, 3]. Values are NOT
            rescaled; they remain in the uint8 range.
        label: int32 tensor of shape [height, width, 1].
    """
    feature_spec = {
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label_raw': tf.FixedLenFeature([], tf.string),
    }
    queue = tf.train.string_input_producer([filename])
    _, serialized = tf.TFRecordReader().read(queue)
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # Heights/widths are stored as int64; reshape wants int32 dims here.
    h = tf.cast(parsed['height'], tf.int32)
    w = tf.cast(parsed['width'], tf.int32)

    pixels = tf.reshape(tf.decode_raw(parsed['image_raw'], tf.uint8), [h, w, 3])
    mask = tf.reshape(
        tf.cast(tf.decode_raw(parsed['label_raw'], tf.uint8), tf.int32),
        [h, w, 1],
    )
    return tf.cast(pixels, tf.float32), mask
# Demo 1: read TFRecords, resize to a fixed size, and pull shuffled batches.
train_path = "/media/cv/DataB/xj/pycharm_project/tensorflow1.14/data/train.tfrecords"
val_path = "/media/cv/DataB/xj/pycharm_project/tensorflow1.14/data/val.tfrecords"

print(">>>>>>>>>>>>>>>>>>>>>>>>")
# Count records by iterating each file once; the payloads themselves are unused.
train_nums = sum(1 for _ in tf.python_io.tf_record_iterator(train_path))
print("train_nums: ", train_nums)
val_nums = sum(1 for _ in tf.python_io.tf_record_iterator(val_path))
print("val_nums: ", val_nums)

img, label = read_and_decode(train_path)
img, label = normalize(img, label)
# Records may differ in size; resizing to a fixed size gives shuffle_batch
# the fully defined shapes it requires.
img = tf.image.resize_images(img, [384, 1024])
# The label holds per-pixel class ids. Nearest-neighbour resizing keeps them
# exact (and keeps the int32 dtype). The default bilinear method would blend
# neighbouring ids into meaningless fractional float values.
label = tf.image.resize_images(
    label, [384, 1024], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
img_batch, label_batch = tf.train.shuffle_batch(
    [img, label],
    batch_size=4,
    num_threads=4,
    capacity=1000,
    min_after_dequeue=900)

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; this is its replacement.
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(3):
            image, label = sess.run([img_batch, label_batch])
            print("img: ", image.shape)
            plt.imshow(image[i, :, :, :])
            plt.show()
    finally:
        # Always stop and join the reader threads, even if a run fails,
        # so the process does not hang with live queue-runner threads.
        coord.request_stop()
        coord.join(threads)
2、直接逐条读取（不进行 batch / shuffle）
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
def normalize(image, label):
    """Return ``(image / 255, label)``, casting the image to float32 first.

    The label is passed through untouched.
    """
    image_f32 = tf.cast(image, dtype=tf.float32)
    return image_f32 / 255, label
def read_and_decode(filename):
    """Return ``(image, label)`` tensors for the next record of ``filename``.

    The tensors come from a TF1 queue pipeline: start the queue runners
    before evaluating them, and note that each evaluation reads a fresh
    record.

    Expected record features: ``height``/``width`` (int64), plus
    ``image_raw``/``label_raw`` (uint8 byte strings).

    Returns:
        image: float32 tensor shaped [height, width, 3], not rescaled.
        label: int32 tensor shaped [height, width, 1].
    """
    queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    record = reader.read(queue)[1]

    fixed = tf.FixedLenFeature
    parsed = tf.parse_single_example(
        record,
        features={
            'height': fixed([], tf.int64),
            'width': fixed([], tf.int64),
            'image_raw': fixed([], tf.string),
            'label_raw': fixed([], tf.string),
        },
    )

    height, width = (tf.cast(parsed[key], tf.int32) for key in ('height', 'width'))

    image_bytes = tf.decode_raw(parsed['image_raw'], tf.uint8)
    label_bytes = tf.decode_raw(parsed['label_raw'], tf.uint8)
    label_ids = tf.cast(label_bytes, tf.int32)

    image_hw3 = tf.reshape(image_bytes, [height, width, 3])
    label_hw1 = tf.reshape(label_ids, [height, width, 1])
    return tf.cast(image_hw3, tf.float32), label_hw1
# Demo 2: read and display single decoded records directly (no batching).
train_path = "/media/cv/DataB/xj/pycharm_project/tensorflow1.14/data/train.tfrecords"
val_path = "/media/cv/DataB/xj/pycharm_project/tensorflow1.14/data/val.tfrecords"

print(">>>>>>>>>>>>>>>>>>>>>>>>")
# Count records by iterating each file once; the payloads themselves are unused.
train_nums = sum(1 for _ in tf.python_io.tf_record_iterator(train_path))
print("train_nums: ", train_nums)
val_nums = sum(1 for _ in tf.python_io.tf_record_iterator(val_path))
print("val_nums: ", val_nums)

img, label = read_and_decode(train_path)
img, label = normalize(img, label)

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; this is its replacement.
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(3):
            # Evaluate once per iteration. Every evaluation of `img` dequeues
            # a new record, so separate eval() calls for print and imshow
            # would report the shape of one image and display another.
            image = sess.run(img)
            print(image.shape)
            plt.imshow(image)
            plt.show()
    finally:
        # Always stop and join the reader threads, even if a run fails.
        coord.request_stop()
        coord.join(threads)