注意:下面这个地址介绍的是TensorFlow的数据读取机制,如果已经了解,请跳过。
原博客地址:https://zhuanlan.zhihu.com/p/27238630
建议阅读博客:https://blog.csdn.net/pursuit_zhangyu/article/details/80607529
代码地址:https://github.com/hzy46/Deep-Learning-21-Examples/tree/master/chapter_2
在上一篇文章中,我用自己的图片制作了数据集;本文介绍如何读取这个自制的数据集。
数据集地址:
链接:https://pan.baidu.com/s/1aIHzKsxUb67sJZAFrGH1ZQ
提取码:lvjp
工程地址:
链接:https://pan.baidu.com/s/1XGAA6UQ0JByhvDYQ__my4g
提取码:dxpn
import numpy as np
import tensorflow as tf
batchSize = 15
num_epochs = 20
def tfRecordRead(fileNameQue, heigh, width, channels, n_class):
    """Build the ops that read and decode a single example from a TFRecord queue.

    Args:
        fileNameQue: Filename queue handed to ``tf.TFRecordReader.read``.
        heigh: Image height used when reshaping the decoded pixels.
        width: Image width used when reshaping the decoded pixels.
        channels: Channel count used when reshaping the decoded pixels.
        n_class: Depth of the one-hot label vector.

    Returns:
        A ``(image, labels)`` pair: the decoded image reshaped to
        ``[heigh, width, channels]`` and the one-hot encoded label.
    """
    # Each serialized Example carries the raw image bytes and an integer label.
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }

    record_reader = tf.TFRecordReader()
    # read() returns a (key, value) pair; only the serialized value is needed.
    _record_key, serialized = record_reader.read(fileNameQue)
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # The bytes are interpreted as float32 values.
    # NOTE(review): this must match the dtype used when the record was written;
    # a uint8 record would need tf.uint8 here plus a cast/rescale -- confirm.
    pixels = tf.decode_raw(parsed['image'], tf.float32)
    image = tf.reshape(pixels, [heigh, width, channels])

    class_index = tf.cast(parsed['label'], tf.int64)
    labels = tf.one_hot(class_index, n_class)
    return image, labels
def tfRecordBatchRead(filename, heigh, width, channels, n_class, batchSize,
                      min_after_dequeue=1000):
    """Build a shuffled batch pipeline over a single TFRecord file.

    Args:
        filename: Path of the TFRecord file to read.
        heigh: Image height forwarded to ``tfRecordRead``.
        width: Image width forwarded to ``tfRecordRead``.
        channels: Channel count forwarded to ``tfRecordRead``.
        n_class: Number of classes forwarded to ``tfRecordRead``.
        batchSize: Number of examples per batch.
        min_after_dequeue: Minimum number of elements kept in the shuffle
            queue after a dequeue. Defaults to 1000, the value that was
            previously hard-coded.

    Returns:
        A ``(imageBatch, labelBatch)`` pair of batched tensors; compared with
        the single-example tensors they gain a leading batch dimension.
    """
    # The epoch limit comes from the module-level ``num_epochs``. The caller
    # runs tf.local_variables_initializer() before starting the queue runners.
    fileNameQue = tf.train.string_input_producer(
        [filename], shuffle=False, num_epochs=num_epochs)
    image, labels = tfRecordRead(fileNameQue, heigh, width, channels, n_class)

    # Head-room above the shuffle threshold for prefetched examples.
    capacity = min_after_dequeue + 3 * batchSize

    # Prefetch examples, shuffle them and assemble batches.
    imageBatch, labelBatch = tf.train.shuffle_batch(
        [image, labels],
        batch_size=batchSize,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return imageBatch, labelBatch
# Driver script: build the TFRecord input pipeline and print a few label batches.

# Forward slashes so the path resolves on Windows and on POSIX systems alike
# (the previous literal backslash only worked on Windows).
filename = './record/Imageoutput.tfrecords'

# The .npz archive supplies a slice of test data plus the image geometry and
# class count that the TFRecord decoder needs.
dataset = np.load('testData.npz')
x_test = dataset['test_X'][1:20]
y_test = dataset['test_Y'][1:20]
heigh, width, channels, n_class = dataset['height'], dataset['width'], dataset['channels'], dataset['n_class']
print(heigh, width, channels, n_class)

imageBatch, labelBatch = tfRecordBatchRead(filename, heigh, width, channels, n_class, batchSize)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Local variables are initialized as well, before the queue runners start.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Fetch and print one batch of labels per iteration.
        for step in range(num_epochs):
            example, label = sess.run([imageBatch, labelBatch])
            print(label)
    except tf.errors.OutOfRangeError:
        # Raised once the input queue is exhausted (epoch limit reached).
        print('Input queue exhausted before all batches were read.')
    finally:
        # Always stop the reader threads, even if sess.run raised.
        coord.request_stop()
        coord.join(threads)
上述代码把label输出了,运行结果如下:
注意事项:
用TFRecord读取数据集的过程,比前面说的TensorFlow数据读取机制多了一步:从TFRecord文件中解析出数据。
features = tf.parse_single_example(
serialized_example,
features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
})
# 将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['image'], tf.float32)
# image = tf.decode_raw(features["image"], tf.uint8)
image = tf.reshape(image, [heigh, width, channels])
# image = tf.cast(image, tf.float32) * (1 / 255.0)
labels = tf.cast(features['label'], tf.int64)
labels = tf.one_hot(labels, n_class)