TensorFlow TFRecords 数据的读取

TensorFlow TFRecords 数据的读取

1、读取过程中进行 batch + shuffle 操作（tf.train.shuffle_batch）

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt

def normalize(image, label):
    """Rescale image pixel values by 1/255 and pass the label through.

    Args:
        image: image tensor of any numeric dtype; it is cast to float32
            before scaling (presumably 8-bit pixel values, so the result
            lands in [0, 1] -- TODO confirm against the data writer).
        label: returned unchanged.

    Returns:
        Tuple ``(scaled_image, label)`` where ``scaled_image`` is float32.
    """
    as_float = tf.cast(image, tf.float32)
    return as_float / 255, label


def read_and_decode(filename):
    """Build TF1 queue-pipeline ops that read and decode one TFRecord example.

    The returned tensors only produce data after
    ``tf.train.start_queue_runners`` has been called in a session; each
    evaluation dequeues the next serialized record.

    Args:
        filename: path of a TFRecord file, or a list/tuple of such paths.
            A single path is wrapped in a list, exactly as before.

    Returns:
        Tuple ``(image, label)``:
            image: float32 tensor of shape ``[height, width, 3]`` holding the
                raw (unscaled) pixel values.
            label: int32 tensor of shape ``[height, width, 1]``.
        ``height`` and ``width`` are read from each record, so the static
        shapes are not fully defined (``[?, ?, 3]`` / ``[?, ?, 1]``).
    """
    # Accept one path (original interface) or a sequence of paths; wrapping
    # an already-list argument in another list would break the producer.
    if isinstance(filename, (list, tuple)):
        filenames = list(filename)
    else:
        filenames = [filename]

    # Queue of file names that feeds the reader.
    filename_queue = tf.train.string_input_producer(filenames)

    reader = tf.TFRecordReader()
    # reader.read returns (key, value): a record key and the serialized
    # tf.train.Example. Only the serialized example is needed here.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                            'height': tf.FixedLenFeature([], tf.int64),
                                            'width': tf.FixedLenFeature([], tf.int64),
                                            'image_raw': tf.FixedLenFeature([], tf.string),
                                            'label_raw': tf.FixedLenFeature([], tf.string)
                                       })

    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)

    # Raw byte buffers -> flat uint8 vectors. The reshapes below assume
    # image_raw holds height*width*3 bytes and label_raw holds
    # height*width bytes (one uint8 per label pixel) -- TODO confirm
    # against the writer.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.decode_raw(features['label_raw'], tf.uint8)
    label = tf.cast(label, tf.int32)
    image = tf.reshape(image, [height, width, 3])
    label = tf.reshape(label, [height, width, 1])
    image = tf.cast(image, tf.float32)

    return image, label



train_path = "/media/cv/DataB/xj/pycharm_project/tensorflow1.14/data/train.tfrecords"
val_path = "/media/cv/DataB/xj/pycharm_project/tensorflow1.14/data/val.tfrecords"

# Count the records in each file by iterating over it once;
# tf_record_iterator yields one serialized record per step.
print(">>>>>>>>>>>>>>>>>>>>>>>>")
train_nums = sum(1 for _ in tf.python_io.tf_record_iterator(train_path))
print("train_nums: ", train_nums)
val_nums = sum(1 for _ in tf.python_io.tf_record_iterator(val_path))
print("val_nums: ", val_nums)


img, label = read_and_decode(train_path)
img, label = normalize(img, label)
# shuffle_batch requires fully defined static shapes, but each record carries
# its own height/width. Resizing to a fixed size fixes the static shape;
# without it shuffle_batch fails with "All shapes must be fully defined".
img = tf.image.resize_images(img, [384, 1024])
# The label is a per-pixel integer map. Nearest-neighbour resizing avoids
# interpolating between label values (which would invent values that never
# occur) and, unlike the default bilinear method, keeps the int32 dtype.
label = tf.image.resize_images(label, [384, 1024],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=4,
                                                num_threads=4,
                                                capacity=1000,
                                                # NOTE(review): 900 examples
                                                # (~6.3 MB each: float32 image
                                                # + int32 label at 384x1024)
                                                # are buffered before the first
                                                # batch, i.e. ~5.7 GB of RAM.
                                                # Lower capacity and
                                                # min_after_dequeue if needed.
                                                min_after_dequeue=900)

with tf.Session() as sess:
    # tf.initialize_all_variables is deprecated. Initialize local variables
    # too: string_input_producer creates a local epoch counter whenever
    # num_epochs is set.
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    coord = tf.train.Coordinator()  # coordinates the queue-runner threads
    # Start the QueueRunners; from here on the filename queue is being filled.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(3):
            # Each run dequeues a fresh shuffled batch.
            image_np, label_np = sess.run([img_batch, label_batch])
            print("img: ", image_np.shape)
            # Show the first image of each batch. Indexing by the loop
            # counter would break whenever batch_size <= iteration count.
            plt.imshow(image_np[0])
            plt.show()
    finally:
        # Always stop and join the reader threads, even if a step fails.
        coord.request_stop()
        coord.join(threads)

2、直接读取（不做 batch/shuffle，每次取出一条记录）

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt

def normalize(image, label):
    """Return ``(image / 255 as float32, label)``.

    Args:
        image: image tensor of any numeric dtype; cast to float32 before
            dividing (presumably 8-bit pixels, giving values in [0, 1] --
            TODO confirm).
        label: passed through untouched.

    Returns:
        Tuple of the float32 rescaled image and the original label.
    """
    float_image = tf.cast(image, tf.float32)
    scaled = float_image / 255
    return scaled, label


def read_and_decode(filename):
    """Build TF1 queue-pipeline ops that read and decode one TFRecord example.

    The returned tensors only produce data after
    ``tf.train.start_queue_runners`` has been called in a session; each
    evaluation dequeues the next serialized record.

    Args:
        filename: path of a TFRecord file, or a list/tuple of such paths.
            A single path is wrapped in a list, exactly as before.

    Returns:
        Tuple ``(image, label)``:
            image: float32 tensor of shape ``[height, width, 3]`` holding the
                raw (unscaled) pixel values.
            label: int32 tensor of shape ``[height, width, 1]``.
        ``height`` and ``width`` are read from each record, so the static
        shapes are not fully defined (``[?, ?, 3]`` / ``[?, ?, 1]``).
    """
    # Accept one path (original interface) or a sequence of paths; wrapping
    # an already-list argument in another list would break the producer.
    if isinstance(filename, (list, tuple)):
        filenames = list(filename)
    else:
        filenames = [filename]

    # Queue of file names that feeds the reader.
    filename_queue = tf.train.string_input_producer(filenames)

    reader = tf.TFRecordReader()
    # reader.read returns (key, value): a record key and the serialized
    # tf.train.Example. Only the serialized example is needed here.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                            'height': tf.FixedLenFeature([], tf.int64),
                                            'width': tf.FixedLenFeature([], tf.int64),
                                            'image_raw': tf.FixedLenFeature([], tf.string),
                                            'label_raw': tf.FixedLenFeature([], tf.string)
                                       })

    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)

    # Raw byte buffers -> flat uint8 vectors. The reshapes below assume
    # image_raw holds height*width*3 bytes and label_raw holds
    # height*width bytes (one uint8 per label pixel) -- TODO confirm
    # against the writer.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.decode_raw(features['label_raw'], tf.uint8)
    label = tf.cast(label, tf.int32)
    image = tf.reshape(image, [height, width, 3])
    label = tf.reshape(label, [height, width, 1])
    image = tf.cast(image, tf.float32)

    return image, label



train_path = "/media/cv/DataB/xj/pycharm_project/tensorflow1.14/data/train.tfrecords"
val_path = "/media/cv/DataB/xj/pycharm_project/tensorflow1.14/data/val.tfrecords"

# Count the records in each file by iterating over it once;
# tf_record_iterator yields one serialized record per step.
print(">>>>>>>>>>>>>>>>>>>>>>>>")
train_nums = sum(1 for _ in tf.python_io.tf_record_iterator(train_path))
print("train_nums: ", train_nums)
val_nums = sum(1 for _ in tf.python_io.tf_record_iterator(val_path))
print("val_nums: ", val_nums)


img, label = read_and_decode(train_path)
img, label = normalize(img, label)

with tf.Session() as sess:
    # tf.initialize_all_variables is deprecated. Initialize local variables
    # too: string_input_producer creates a local epoch counter whenever
    # num_epochs is set.
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    coord = tf.train.Coordinator()  # coordinates the queue-runner threads
    # Start the QueueRunners; from here on the filename queue is being filled.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(3):
            # Evaluate img exactly once per iteration. Every evaluation
            # dequeues the next record, so calling img.eval() separately for
            # the print and for imshow would report the shape of one image,
            # display a different one, and skip records.
            image = sess.run(img)
            print(image.shape)
            plt.imshow(image)
            plt.show()
    finally:
        # Always stop and join the reader threads, even if a step fails.
        coord.request_stop()
        coord.join(threads)

你可能感兴趣的:(tfrecords,tensorflow,tfrecords,读取)