1.构建文件列表（使用 CIFAR-10 二进制版数据集）
# Collect the paths of all CIFAR-10 binary batch files (*.bin) under `path`.
path = './data/cifar/'
file_names = os.listdir(path)
# os.path.join (os.join does not exist) builds each full file path.
file_list = [os.path.join(path, file_name) for file_name in file_names if file_name.endswith('.bin')]
2.开启队列读取文件列表
# Feed the file names into a filename queue that the reader will dequeue from.
queue_list = tf.train.string_input_producer(string_tensor=file_list)
3.构建阅读器读取数据
# Excerpt of CirarReader.read_decode_cifar: records are a fixed self.bytes long,
# so a fixed-length reader pulls one record at a time from the filename queue.
reader = tf.FixedLengthRecordReader(record_bytes=self.bytes)
key, value = reader.read(queue=queue_list)
4.解析数据
# Decode the raw record bytes into a flat uint8 tensor
# (tf.decode_raw; the name label_image is what the slicing step below uses).
label_image = tf.decode_raw(value, tf.uint8)
5.将数据切成特征值和目标值
# Split the flat record: the leading label_bytes bytes are the label (target),
# the following image_bytes bytes are the image (features).
label = tf.slice(label_image, begin=[0], size=[self.label_bytes])
image = tf.slice(label_image, begin=[self.label_bytes], size=[self.image_bytes])
6.对特征值（图像）进行形状改变
# CIFAR-10 stores the image bytes channel-major (all red, then all green, then
# all blue, each plane row-major), so reshape to [channel, height, width] first
# and then transpose to [height, width, channel]; reshaping straight to
# [height, width, channel] would scramble the pixels.
image_reshape = tf.transpose(tf.reshape(image, [self.channel, self.height, self.width]), [1, 2, 0])
7.进行批处理
# Batch the reshaped image (not the flat `image`) together with its label.
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=3, capacity=20)
8.开启会话，启动队列线程并读取一个批次的数据
# Start the queue-runner threads that fill the input queues, then fetch one batch.
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    try:
        print(sess.run([image_batch, label_batch]))
    finally:
        # Stop and join the reader threads even if sess.run raises, so they
        # are not left running when the session closes.
        coord.request_stop()
        coord.join(threads)
完整代码
import tensorflow as tf
import os
class CirarReader:
    """Read CIFAR-10 binary batch files with a TF1 queue-based input pipeline.

    The class name and the ``filelsit`` parameter keep their original spelling
    so existing callers keep working.

    Args:
        filelsit: list of paths to the binary batch files to read.
    """

    def __init__(self, filelsit):
        self.file_list = filelsit
        # Geometry of one image.
        self.height = 32
        self.width = 32
        self.channel = 3
        # One record = label_bytes label byte(s) followed by image_bytes image bytes.
        self.label_bytes = 1
        self.image_bytes = self.width * self.height * self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_decode_cifar(self, batch_size=10, num_threads=3, capacity=20):
        """Build graph ops that read, decode and batch records from the files.

        The ops only yield data once ``tf.train.start_queue_runners`` has been
        called inside a session.

        Args:
            batch_size: number of examples per batch.
            num_threads: number of threads enqueuing examples.
            capacity: maximum number of examples held in the batch queue.

        Returns:
            (image_batch, label_batch): uint8 tensors of shape
            [batch_size, height, width, channel] and [batch_size, label_bytes].
        """
        # Filename queue the reader dequeues from.
        queue_list = tf.train.string_input_producer(self.file_list)
        # Every record has the same length, so a fixed-length reader splits them.
        reader = tf.FixedLengthRecordReader(self.bytes)
        _, value = reader.read(queue_list)
        # Decode the raw record bytes into a flat uint8 tensor.
        label_image = tf.decode_raw(value, tf.uint8)
        # Split into the label (target) and the image (features).
        label = tf.slice(label_image, [0], [self.label_bytes])
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
        # CIFAR-10 stores pixels channel-major (all red, then green, then blue,
        # each plane row-major): reshape to [C, H, W], then transpose to [H, W, C].
        # Reshaping straight to [H, W, C] would scramble the pixels.
        depth_major = tf.reshape(image, [self.channel, self.height, self.width])
        image_reshape = tf.transpose(depth_major, [1, 2, 0])
        image_batch, label_batch = tf.train.batch(
            [image_reshape, label],
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=capacity,
        )
        return image_batch, label_batch
if __name__ == '__main__':
    path = './data/cifar/'
    file_names = os.listdir(path)
    # NOTE(review): in the standard CIFAR-10 binary release this filter also
    # picks up test_batch.bin alongside the data batches — confirm that is intended.
    file_list = [os.path.join(path, file_name) for file_name in file_names if file_name.endswith('.bin')]
    if not file_list:
        # Fail early with a message naming the directory that was searched.
        raise FileNotFoundError('no .bin files found in %s' % path)
    reader = CirarReader(file_list)
    image_batch, label_batch = reader.read_decode_cifar()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            print(sess.run([image_batch, label_batch]))
        finally:
            # Stop and join the reader threads even if sess.run raises.
            coord.request_stop()
            coord.join(threads)