TensorFlow 读取 TFRecords 文件：支持批量读取，双输入（image_2 / image_3 两路图像 + 标签）

保存（将图像与标签写入 TFRecords 文件）

import glob
import cv2
import numpy as np
import tensorflow as tf
import sys


def load_image(addr, size=(1242, 375)):
    """Read an image from disk, resize it and return it as float32 RGB.

    Args:
        addr: Path of the image file to read.
        size: Target size as ``(width, height)``, the order ``cv2.resize``
            expects. Defaults to ``(1242, 375)``, which matches the
            ``[375, 1242, 3]`` reshape used when reading the records back.

    Returns:
        A ``float32`` array of shape ``(height, width, 3)`` in RGB channel
        order, with the pixel values of the source image.

    Raises:
        IOError: If OpenCV cannot read ``addr`` (missing file, bad format).
    """
    img = cv2.imread(addr)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the problem shows up later as an opaque cv2.resize error.
    if img is None:
        raise IOError('Could not read image: {}'.format(addr))
    img = cv2.resize(img, size, interpolation=cv2.INTER_CUBIC)
    # OpenCV loads colour images in BGR order; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img.astype(np.float32)


# Helpers that wrap Python values in tf.train.Feature protos.
def _int64_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element int64 list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element bytes list."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


# Input globs for the two image streams and the labels. The results are sorted
# because glob.glob returns files in arbitrary, filesystem-dependent order;
# without sorting, the i-th image_2, image_3 and dis files are not guaranteed
# to belong to the same sample. (Assumes matching samples share a sortable
# file name across the three directories.)
image_2_path = 'train/image_2/*.png'
image_3_path = 'train/image_3/*.png'
dis_path = 'train/dis/*.png'
image_2_addrs = sorted(glob.glob(image_2_path))
image_3_addrs = sorted(glob.glob(image_3_path))
dis_addrs = sorted(glob.glob(dis_path))

# The three lists are zipped below, so they must be non-empty and equally long;
# otherwise samples would be silently dropped or an empty file written.
if not image_2_addrs:
    raise ValueError('No images found matching {}'.format(image_2_path))
if not (len(image_2_addrs) == len(image_3_addrs) == len(dis_addrs)):
    raise ValueError(
        'Mismatched file counts: {} image_2, {} image_3, {} dis'.format(
            len(image_2_addrs), len(image_3_addrs), len(dis_addrs)))

# Write data into a TFRecords file
train_filename = 'train.tfrecords'  # address to save the TFRecords file

# The context manager closes the writer even if loading an image raises
# part-way through the loop.
with tf.python_io.TFRecordWriter(train_filename) as writer:
    for addr2, addr3, addr_dis in zip(image_2_addrs, image_3_addrs, dis_addrs):
        img2 = load_image(addr2)
        img3 = load_image(addr3)
        # NOTE(review): the label goes through the same 8-bit colour read and
        # cubic resize as the images; if these are 16-bit disparity maps this
        # loses precision - confirm against the dataset.
        label = load_image(addr_dis)

        # Each array is stored as raw float32 bytes; the reader decodes them
        # with tf.decode_raw(..., tf.float32) and reshapes to (375, 1242, 3).
        feature = {'train/label': _bytes_feature(tf.compat.as_bytes(label.tobytes())),
                   'train/image_2': _bytes_feature(tf.compat.as_bytes(img2.tobytes())),
                   'train/image_3': _bytes_feature(tf.compat.as_bytes(img3.tobytes()))}

        # Wrap the features in an Example proto, serialize and append it.
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

sys.stdout.flush()

读取（从 TFRecords 文件批量读取并显示）

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

data_path = 'train.tfrecords'  # TFRecords file produced by the writer script


def read_TFRecord(num_epochs, batch_size, path=None, image_shape=(375, 1242, 3),
                  capacity=3000, min_after_dequeue=1000):
    """Build a queue-based TF1 input pipeline yielding shuffled stereo batches.

    Args:
        num_epochs: Number of passes over the file before the queue raises
            ``tf.errors.OutOfRangeError`` (``None`` cycles forever).
        batch_size: Number of examples per returned batch.
        path: TFRecords file to read; defaults to the module-level ``data_path``.
        image_shape: Shape each decoded tensor is reshaped to. Must match the
            shape of the float32 arrays that were written.
        capacity: Maximum number of examples buffered in the shuffle queue.
        min_after_dequeue: Minimum number of examples kept in the queue after a
            dequeue; larger values mix examples better but use more memory.

    Returns:
        Tuple ``(images2, images3, labels)`` of float32 tensors, each of shape
        ``[batch_size] + list(image_shape)``.

    Note:
        With the default shape one example (three float32 images) is about
        16.8 MB, so ``capacity=3000`` can buffer roughly 50 GB. Lower
        ``capacity``/``min_after_dequeue`` if memory is limited.
    """
    if path is None:
        path = data_path
    image_shape = list(image_shape)

    feature = {'train/image_2': tf.FixedLenFeature([], tf.string),
               'train/image_3': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.string)}

    # Filename queue over a single file. When num_epochs is set, the epoch
    # counter is a local variable, so the caller must run
    # tf.local_variables_initializer() before starting the queue runners.
    filename_queue = tf.train.string_input_producer([path], num_epochs=num_epochs)

    # Read one serialized record and parse it into its three byte strings.
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features=feature)

    # The writer stored raw float32 bytes, so decode with the same dtype and
    # restore the original array shape.
    image2 = tf.reshape(tf.decode_raw(features['train/image_2'], tf.float32), image_shape)
    image3 = tf.reshape(tf.decode_raw(features['train/image_3'], tf.float32), image_shape)
    label = tf.reshape(tf.decode_raw(features['train/label'], tf.float32), image_shape)

    # Batch the three tensors together so each batch keeps image_2, image_3 and
    # label of the same example aligned while shuffling across examples.
    images2, images3, labels = tf.train.shuffle_batch(
        [image2, image3, label], batch_size=batch_size, capacity=capacity,
        num_threads=1, min_after_dequeue=min_after_dequeue)
    return images2, images3, labels


NUM_EPOCHS = 2
BATCH_SIZE = 5
# Each figure uses a 2 x 3 subplot grid, so at most 6 images per batch are shown.
GRID_ROWS, GRID_COLS = 2, 3

with tf.Session() as sess:
    images_2, images_3, label_s = read_TFRecord(NUM_EPOCHS, BATCH_SIZE)
    # Initialize all global and local variables (the num_epochs counter of the
    # filename queue is a local variable).
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)

    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop():
            img2, img3, lbl = sess.run([images_2, images_3, label_s])
            # Values were written as float32 copies of 8-bit pixels; cast back
            # so imshow treats them as 0-255 integers.
            img2 = img2.astype(np.uint8)
            img3 = img3.astype(np.uint8)
            lbl = lbl.astype(np.uint8)

            # Derive the count from the batch itself instead of a hard-coded 5,
            # capped at the number of grid slots.
            n_show = min(len(img2), GRID_ROWS * GRID_COLS)
            for j in range(n_show):
                plt.figure('image_2')
                plt.subplot(GRID_ROWS, GRID_COLS, j + 1)
                plt.imshow(img2[j, ...])

                plt.figure('image_3')
                plt.subplot(GRID_ROWS, GRID_COLS, j + 1)
                plt.imshow(img3[j, ...])

                plt.figure('dis2')
                plt.subplot(GRID_ROWS, GRID_COLS, j + 1)
                plt.imshow(lbl[j, ...])
            plt.show()
    except tf.errors.OutOfRangeError:
        # Raised once all num_epochs passes over the file are consumed.
        print('Done training for epochs')
    finally:
        # Stop the threads
        coord.request_stop()

    # Wait for threads to stop; the enclosing `with` closes the session.
    coord.join(threads)

你可能感兴趣的:(tensorflow)