tensorflow 读取cifar_TensorFlow TFRecords 写入和读取

一.图片和图片标签写入TFRecords

1.创建文件存储器

# Open a TFRecord writer on the output file; serialized examples are appended through it.
writer = tf.io.TFRecordWriter('./data/tfrecords/cifar.tfrecords')

2.通过for循环将读取的数据封装成example并写入TFRecords

# Evaluate the image and label batches in ONE run call. Calling .eval() on
# each tensor separately dequeues a fresh batch every time, so the image and
# the label written together would come from different batches.
images, labels = tf.get_default_session().run([image_batch, label_batch])

# Each batch holds ten images.
for i in range(10):
    # Raw uint8 pixel bytes of the i-th image.
    image = images[i].tobytes()
    # Each label is a length-1 vector; convert it to a plain Python int for Int64List.
    label = int(labels[i][0])
    # Wrap the image and label in an Example protocol buffer.
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))
    # Serialize the example and append it to the TFRecords file.
    writer.write(example.SerializeToString())

# Close the writer once, after all examples are written.
writer.close()

二.读取TFRecords

1.构建文件队列

# Build the input file-name queue; here it contains only the TFRecords file.
queue_list=tf.train.string_input_producer(['./data/tfrecords/cifar.tfrecords'])

2.构造阅读器

# Reader for TFRecords files.
reader=tf.TFRecordReader()

# read() yields a (key, value) pair taken from the file queue; value is what gets parsed next.
key,value=reader.read(queue_list)

3.解析读取的example

# Parse one serialized example. The keys and dtypes must match what was
# written: 'image' is a single bytes string, 'label' a single int64.
features=tf.parse_single_example(value,features={
    'image':tf.FixedLenFeature([],tf.string),
    'label':tf.FixedLenFeature([],tf.int64)
})

4.解码内容

# Decode the raw image bytes into a flat uint8 tensor.
image=tf.decode_raw(features['image'],tf.uint8)
# Convert the int64 label to float32.
label=tf.cast(features['label'],tf.float32)

5.固定图片形状

# Reshape the flat decoded tensor to a fixed [height, width, channel] shape.
image_reshape=tf.reshape(image,[self.height, self.width, self.channel])

6.进行批处理

# Group single examples into batches of 10, using one thread and a queue capacity of 10.
batch_image,label_batch=tf.train.batch([image_reshape,label],batch_size=10,num_threads=1,capacity=10)

完整的写入和读取代码

import tensorflow as tf

import os

# Handle to the flag values registered through tf.app.flags.
FLAGS = tf.app.flags.FLAGS

# Register the string flag 'cifar_tfrecords'; its help text means "tfrecords directory".
# NOTE(review): the code visible in this file hard-codes the same path instead of
# reading FLAGS.cifar_tfrecords.
cifar_tfrecords = tf.app.flags.DEFINE_string('cifar_tfrecords', './data/tfrecords/cifar.tfrecords', 'tfrecords目录')

class CirarReader():
    """Read CIFAR binary records and convert them to and from TFRecords.

    Uses the TensorFlow 1.x queue-based input pipeline. Each binary record
    is treated as ``label_bytes`` label bytes followed by ``image_bytes``
    image bytes.
    """

    # Examples per batch produced by the reading pipelines.
    BATCH_SIZE = 10
    # Default location of the TFRecords file, for both writing and reading.
    TFRECORDS_PATH = './data/tfrecords/cifar.tfrecords'

    def __init__(self, filelsit):
        """Store the input file list and the fixed record layout.

        Args:
            filelsit: list of paths to the binary files to read.
        """
        self.file_list = filelsit
        self.height = 32
        self.width = 32
        self.channel = 3
        self.label_bytes = 1
        self.image_bytes = self.width * self.height * self.channel
        # Total size of one record: label followed by image.
        self.bytes = self.label_bytes + self.image_bytes

    def read_decode_cifar(self):
        """Build the graph that reads and batches the binary files.

        Returns:
            (image_batch, label_batch): uint8 tensors, one batch of images
            reshaped to [height, width, channel] and one batch of
            ``label_bytes``-long label vectors.
        """
        queue_list = tf.train.string_input_producer(self.file_list)
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(queue_list)
        # Decode the raw record into a flat uint8 vector.
        label_image = tf.decode_raw(value, tf.uint8)
        # Split the record into the label (target) and the image (features).
        label = tf.slice(label_image, [0], [self.label_bytes])
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
        # NOTE(review): the CIFAR-10 binary format is believed to store pixels
        # channel-major; confirm whether a reshape to [channel, height, width]
        # plus a transpose is needed before viewing the images.
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
        # Batch single examples together.
        image_batch, label_batch = tf.train.batch(
            [image_reshape, label],
            batch_size=self.BATCH_SIZE, num_threads=3, capacity=20)
        return image_batch, label_batch

    def write_to_tfrecords(self, image_batch, label_batch, path=None):
        """Evaluate one batch and write every example to a TFRecords file.

        Must be called with a default session active and queue runners
        started, because the batch tensors are evaluated here.

        Args:
            image_batch: batched image tensor from ``read_decode_cifar``.
            label_batch: batched label tensor from ``read_decode_cifar``.
            path: output file; defaults to ``TFRECORDS_PATH``.

        Raises:
            ValueError: if no default session is registered.
        """
        if path is None:
            path = self.TFRECORDS_PATH
        session = tf.get_default_session()
        if session is None:
            raise ValueError('write_to_tfrecords requires a default tf.Session')
        # Evaluate both tensors in a single run call. Evaluating them
        # separately (or per element) dequeues a new batch every time, which
        # pairs images with labels from a different batch.
        images, labels = session.run([image_batch, label_batch])
        # The context manager closes the writer even if serialization fails.
        with tf.io.TFRecordWriter(path) as writer:
            for image, label in zip(images, labels):
                # One Example protocol buffer per image/label pair.
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                    # Each label is a length-1 vector; Int64List needs a plain int.
                    "label": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[int(label[0])])),
                }))
                writer.write(example.SerializeToString())

    def read_from_tfrecords(self, path=None):
        """Build the graph that reads and batches examples from TFRecords.

        Args:
            path: TFRecords file to read; defaults to ``TFRECORDS_PATH``.

        Returns:
            (batch_image, batch_label): a uint8 image batch reshaped to
            [height, width, channel] per example, and a float32 label batch.
        """
        if path is None:
            path = self.TFRECORDS_PATH
        # File-name queue and record reader.
        queue_list = tf.train.string_input_producer([path])
        reader = tf.TFRecordReader()
        key, value = reader.read(queue_list)
        # Parse the serialized example; keys and dtypes mirror write_to_tfrecords.
        features = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        # Decode the image bytes and convert the label to float32.
        image = tf.decode_raw(features['image'], tf.uint8)
        label = tf.cast(features['label'], tf.float32)
        # decode_raw yields a flat vector; give it a fixed image shape.
        image_shape = tf.reshape(image, [self.height, self.width, self.channel])
        # Batch single examples together.
        batch_image, batch_label = tf.train.batch(
            [image_shape, label],
            batch_size=self.BATCH_SIZE, num_threads=1, capacity=10)
        return batch_image, batch_label

if __name__ == '__main__':
    # Collect the CIFAR binary files (*.bin) from the data directory.
    path = './data/cifar/'
    file_names = os.listdir(path)
    file_list = [os.path.join(path, file_name) for file_name in file_names if file_name.endswith('.bin')]
    reader = CirarReader(file_list)

    # To (re)generate the TFRecords file, build the binary pipeline instead:
    #     image_batch, label_batch = reader.read_decode_cifar()
    # and call reader.write_to_tfrecords(image_batch, label_batch) inside the
    # session below, after the queue runners have started.
    batch_image, batch_label = reader.read_from_tfrecords()

    with tf.Session() as sess:
        # The coordinator manages the queue-runner threads that feed the pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            print(sess.run([batch_image, batch_label]))
        finally:
            # Always stop and join the reader threads, even if run() fails,
            # so the process does not hang on live queue runners.
            coord.request_stop()
            coord.join(threads)

你可能感兴趣的:(tensorflow,读取cifar)