# --- Save: write the training data to a TFRecords file ---
import glob
import cv2
import numpy as np
import tensorflow as tf
import sys
def load_image(addr, size=(1242, 375)):
    """Load an image file as a fixed-size float32 RGB array.

    Args:
        addr: Path of the image file to read.
        size: Target ``(width, height)`` handed to ``cv2.resize``. With the
            default ``(1242, 375)`` the result has 375 rows and 1242 columns.

    Returns:
        numpy.ndarray of dtype float32 with channels in RGB order.

    Raises:
        IOError: If ``cv2.imread`` could not read ``addr``.
    """
    img = cv2.imread(addr)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of inside cv2.resize.
        raise IOError('Could not read image: {!r}'.format(addr))
    img = cv2.resize(img, tuple(size), interpolation=cv2.INTER_CUBIC)
    # OpenCV loads images in BGR channel order; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img.astype(np.float32)
# Helpers that wrap raw values in tf.train.Feature messages.
def _int64_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
# Glob patterns for the two image directories and the label directory.
image_2_path = 'train/image_2/*.png'
image_3_path = 'train/image_3/*.png'
dis_path = 'train/dis/*.png'

# glob.glob() returns paths in arbitrary order, so each list is sorted;
# otherwise entry i of one list need not belong with entry i of the others.
# NOTE(review): this assumes corresponding samples sort identically in all
# three directories (e.g. share a file name) -- confirm against the dataset.
image_2_addrs = sorted(glob.glob(image_2_path))
image_3_addrs = sorted(glob.glob(image_3_path))
dis_addrs = sorted(glob.glob(dis_path))

if not (len(image_2_addrs) == len(image_3_addrs) == len(dis_addrs)):
    # A count mismatch means at least one sample is incomplete; pairing by
    # position would silently combine unrelated files.
    raise ValueError(
        'Mismatched file counts: {} image_2, {} image_3, {} dis'.format(
            len(image_2_addrs), len(image_3_addrs), len(dis_addrs)))

# Write data into a TFRecords file
train_filename = 'train.tfrecords'  # path of the TFRecords file to write
writer = tf.python_io.TFRecordWriter(train_filename)
try:
    for addr_2, addr_3, addr_dis in zip(image_2_addrs, image_3_addrs, dis_addrs):
        img2 = load_image(addr_2)
        img3 = load_image(addr_3)
        # NOTE(review): the label goes through the same load_image() path
        # (resize + BGR->RGB + float32) as the images -- confirm this is
        # intended for the 'dis' files.
        label = load_image(addr_dis)
        # Store each array as its raw bytes; the reader must know dtype and
        # shape to decode them. tobytes() replaces the deprecated tostring().
        feature = {
            'train/label': _bytes_feature(tf.compat.as_bytes(label.tobytes())),
            'train/image_2': _bytes_feature(tf.compat.as_bytes(img2.tobytes())),
            'train/image_3': _bytes_feature(tf.compat.as_bytes(img3.tobytes())),
        }
        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        # Serialize to string and write to the file
        writer.write(example.SerializeToString())
finally:
    # Close the writer even if loading or serialising a sample fails.
    writer.close()
sys.stdout.flush()
# --- Read: load the TFRecords file back and display the batches ---
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
data_path = 'train.tfrecords' # path of the TFRecords file to read
def read_TFRecord(num_epochs, batch_size, capacity=3000, min_after_dequeue=1000,
                  num_threads=1, image_shape=(375, 1242, 3)):
    """Build the queue-based graph ops that read shuffled batches from ``data_path``.

    Each record is expected to carry three raw float32 byte strings under the
    keys 'train/image_2', 'train/image_3' and 'train/label'.

    Args:
        num_epochs: Number of passes over the file, passed to
            ``tf.train.string_input_producer``.
        batch_size: Number of samples per returned batch.
        capacity: Maximum number of samples held in the shuffle queue.
        min_after_dequeue: Minimum number of samples kept in the queue after a
            dequeue, passed to ``tf.train.shuffle_batch``.
        num_threads: Number of threads enqueuing samples.
        image_shape: Shape every decoded tensor is reshaped to.

    Returns:
        Tuple ``(images2, images3, labels)`` of batched tensors.

    Note:
        With the default shape one sample is 3 * 375 * 1242 * 3 float32 values
        (about 16.8 MB), so the default capacity of 3000 corresponds to roughly
        50 GB of buffered data; lower ``capacity`` / ``min_after_dequeue`` on
        machines with less memory.
    """
    feature = {
        'train/image_2': tf.FixedLenFeature([], tf.string),
        'train/image_3': tf.FixedLenFeature([], tf.string),
        'train/label': tf.FixedLenFeature([], tf.string),
    }
    # Queue of input file names, repeated num_epochs times.
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=num_epochs)
    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)

    # Convert the raw byte strings back to float32 and restore their shape.
    shape = list(image_shape)
    image2 = tf.reshape(tf.decode_raw(features['train/image_2'], tf.float32), shape)
    image3 = tf.reshape(tf.decode_raw(features['train/image_3'], tf.float32), shape)
    label = tf.reshape(tf.decode_raw(features['train/label'], tf.float32), shape)

    # Create batches by randomly shuffling the three tensors together.
    images2, images3, labels = tf.train.shuffle_batch(
        [image2, image3, label],
        batch_size=batch_size,
        capacity=capacity,
        num_threads=num_threads,
        min_after_dequeue=min_after_dequeue)
    return images2, images3, labels
# Number of passes over the TFRecords file and number of samples per batch.
NUM_EPOCHS = 2
BATCH_SIZE = 5
# Number of subplot columns per figure; rows are derived from the batch size.
N_COLS = 3

with tf.Session() as sess:
    images_2, images_3, label_s = read_TFRecord(NUM_EPOCHS, BATCH_SIZE)
    # Initialize all global and local variables
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            img2, img3, lbl = sess.run([images_2, images_3, label_s])
            # Cast back to uint8 so imshow treats the values as 0-255 pixels.
            img2 = img2.astype(np.uint8)
            img3 = img3.astype(np.uint8)
            lbl = lbl.astype(np.uint8)
            # Use the actual batch dimension rather than repeating the
            # batch size as a magic number.
            n_samples = img2.shape[0]
            n_rows = max(1, (n_samples + N_COLS - 1) // N_COLS)
            for j in range(n_samples):
                for fig_name, batch in (('image_2', img2), ('image_3', img3), ('dis2', lbl)):
                    plt.figure(fig_name)
                    plt.subplot(n_rows, N_COLS, j + 1)
                    plt.imshow(batch[j, ...])
            plt.show()
    except tf.errors.OutOfRangeError:
        # Raised by sess.run once the input queue has delivered all epochs.
        print('Done training for epochs')
    finally:
        # Stop the threads
        coord.request_stop()
        # Wait for threads to stop
        coord.join(threads)
    # No explicit sess.close(): the enclosing `with` block closes the session.