Python TensorFlow,TFRecord文件类型,读取TFRecord文件,TFRecordReader()

Python TensorFlow,TFRecord文件类型,读取TFRecord文件,TFRecordReader()_第1张图片

Python TensorFlow,TFRecord文件类型,读取TFRecord文件,TFRecordReader()_第2张图片


demo.py(TFRecord,读取TFRecord,TFRecordReader()):

import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'  # 设置警告级别


# 读取tfrecords文件。

# 找到数据文件,放入列表   路径+名字->列表当中
file_names = os.listdir("./mydata/")
print(file_names)  # ['dog.tfrecords']
# 拼接路径和文件名
filename_list = [os.path.join("./mydata/", file) for file in file_names]

# 1、构造文件名队列
file_queue = tf.train.string_input_producer(filename_list)

# 2、构造文件阅读器,读取example协议块
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)  # value是序列化后的example协议块(一个样本对应一个协议块)

# 3、解析example协议块。 解析成字典类型(键值对形式)的样本信息
features = tf.parse_single_example(value, features={
    "image": tf.FixedLenFeature([], tf.string),  # 要与存储的key和数据类型保持对应。
    "label": tf.FixedLenFeature([], tf.int64)
})

# 4、解码内容,解码成数值类型。 如果读取的内容格式是string类型,就需要解码, 如果是int64,float32不需要解码
image = tf.decode_raw(features["image"], tf.uint8)  # string类型解码成uint8类型。

# 固定图片(样本)的形状 (批处理需要数据形状固定)
image_reshape = tf.reshape(image, [32, 32, 3])  # 3表示图片3个通道

label = tf.cast(features["label"], tf.int32)  # 转换类型
print(image_reshape)  # Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8)
print(label)  # Tensor("Cast:0", shape=(), dtype=int32)

# 进行批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)


# 开启会话运行结果
with tf.Session() as sess:
    # 创建一个线程协调器
    coord = tf.train.Coordinator()

    # 开启读文件的子线程
    threads = tf.train.start_queue_runners(sess, coord=coord)

    # 打印读取的内容
    print(sess.run(label_batch))
    '''
    [5 6 0 9 4 3 1 2 9 7]
    '''
    print(sess.run(image_batch))
    '''
    [[[[178 178 178]
       [178 179 179]
       [179 180 180]
       ...
       [176 175 173]
       [171 168 166]
       [163 159 155]]]]
    '''

    # 结束子线程
    coord.request_stop()
    # 等待子线程结束
    coord.join(threads)

 

 

你可能感兴趣的:(Python+,机器学习)