tfrecord格式数据的生成与读取

tfrecord格式数据的生成与读取

tfrecord格式数据简介

tfrecord格式数据的生成

生成 TFRecord 文件的代码如下所示(基于 TensorFlow 1.x 的 API):

import tensorflow as tf
import os
import cv2
import numpy as np
from random import shuffle

# Reduce the verbosity of TensorFlow's native (C++) logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
shuffle_data = True  # whether to randomise the sample order before writing
image_path = './image/'
save_path = './data.tfrecords'
image_list = []  # paths of the image files
label_list = []  # label of each image: the first character of its file name

for root, dirs, files in os.walk(image_path):
    for image_name in files:
        # Join with `root` rather than `image_path`: os.walk also descends
        # into sub-directories, and their files do not live directly under
        # `image_path`.
        image_list.append(os.path.join(root, image_name))
        label_list.append(image_name[0])

# Shuffle paths and labels together so every image keeps its own label.
# Skipped when nothing was found, because unpacking zip(*[]) into two
# names raises ValueError.
if shuffle_data and image_list:
    c = list(zip(image_list, label_list))
    shuffle(c)
    image_list, label_list = (list(column) for column in zip(*c))

# Image loading helper: reads the file with OpenCV and prepares the pixel data.
def loade_image_label(image_list, label_list):
    """Load a single image and convert its label to an integer.

    Despite the plural parameter names, each call handles exactly one sample.

    Args:
        image_list: Path of one image file.
        label_list: Label of that image; any value accepted by ``int()``.

    Returns:
        A tuple ``(image, label)``. ``image`` is the picture resized to
        224x224 with bicubic interpolation, converted from BGR to RGB and
        cast to ``float32`` (pixel values are not rescaled). ``label`` is
        an ``int``.

    Raises:
        IOError: If OpenCV cannot read or decode the file.
        ValueError: If the label cannot be converted to ``int``.
    """
    image = cv2.imread(image_list)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than later inside cv2.resize.
    if image is None:
        raise IOError('cannot read image file: {!r}'.format(image_list))
    image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_CUBIC)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.astype(np.float32)
    label = int(label_list)
    return image, label
# Serialise every (image, label) pair as a tf.train.Example and append it to
# the TFRecord file at `save_path`. The `with` block closes the writer even
# when loading one of the images raises part-way through.
with tf.python_io.TFRecordWriter(save_path) as writer:
    for image_file, image_label in zip(image_list, label_list):
        image, label = loade_image_label(image_file, image_label)
        feature = {
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            # Raw array bytes only: dtype and shape are not stored in the
            # record, so whoever reads the file has to know them.
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

TFRecord 格式数据的读取

import tensorflow as tf
import os
import numpy as np  
import matplotlib.pyplot as plt
import cv2
# Reduce the verbosity of TensorFlow's native (C++) logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
data_path = './data.tfrecords'
# Parsing schema: one raw byte string per image and one int64 label.
feature = {
    'image': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
}

# Queue of input file names; the file list is served for 5 epochs.
filename_queue = tf.train.string_input_producer([data_path], num_epochs=5)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features=feature)
# Decode the raw bytes as float32 and restore the 224x224x3 image shape.
image = tf.decode_raw(features['image'], tf.float32)
label = tf.cast(features['label'], tf.int64)
image = tf.reshape(image, [224, 224, 3])
images, labels = tf.train.shuffle_batch(
    [image, label], batch_size=5, capacity=25, min_after_dequeue=5)
# Local variables are initialised as well: per the TF 1.x documentation,
# `num_epochs` makes string_input_producer create a local epoch counter.
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    try:
        for i in range(5):
            x, y = sess.run([images, labels])
            # Process the batch (x: images, y: labels) here.
    finally:
        # Ask the queue-runner threads to stop and wait for them, so the
        # session is not closed while they are still feeding the queues.
        coord.request_stop()
        coord.join(threads)

你可能感兴趣的:(TensorFlow)