tensorflow的tfrecord创建与读取

对flower_photos的图片建立tfrecord文件

这里的代码不需要调试了,只要把flower_photos解压到与程序同一目录即可,当然还要建立一个flower_photos_tfrecord文件夹。创建开始。。。。

 INPUT_DATA = 'flower_photos'
 OUT_DATA = 'flower_photos_tfrecord'
 READ_DATA = 'flower_photos_tfrecord'
 NUM_SHARDS = 2
 creat_tfcord(INPUT_DATA,OUT_DATA)#就是它了

所有的代码在下面:
因为需要把所有的图片统一到同一尺寸,用了这个image_data.resize((224, 224))

接下来就是读取了

这里有个大坑啊!!!!!!!!

example_batch1, label_batch1 =   read_tfrecord(READ_DATA, batch_size=100)
sess = tf.Session()
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())

就是上面:
example_batch1, label_batch1 = read_tfrecord(READ_DATA, batch_size=100)
这句必须在两个初始化函数前面,而且初始化函数必须写,不然.......

OutOfRangeError: RandomShuffleQueue '_72_shuffle_batch_6/random_shuffle_queue' is closed and has insufficient elements (requested 100, current size 0)

这个大坑我可是一行一行调出来的。。。。

贴上所有的代码:

# -*- coding: utf-8 -*-
"""
Created on Thu Aug 30 14:29:58 2018

@author: yanghe
"""

import tensorflow as tf
import os
import glob
import threading
from PIL import Image
import numpy as np

def get_labels_dirs(INPUT_DATA):
    """Map each class sub-directory under INPUT_DATA to its image file list.

    Walks the immediate sub-directories of INPUT_DATA (the root itself is
    skipped) and collects every file matching the jpg/jpeg extensions.

    Args:
        INPUT_DATA: path to the root directory holding one sub-directory
            per class (e.g. 'flower_photos').

    Returns:
        dict mapping sub-directory (class) name -> list of image file paths.
        Sub-directories with no matching images are omitted.
    """
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    dir_labels = {}
    # os.walk yields the root directory first; x[0] is each directory path.
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    for sub_dir in sub_dirs[1:]:  # skip the root directory itself
        dir_name = os.path.basename(sub_dir)
        file_list = []
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        # Bug fix: the original checked `if not file_list: continue` AFTER
        # recording the entry and at the end of the loop body, which was a
        # no-op — empty directories leaked into the result. Skip them first.
        if not file_list:
            continue
        dir_labels[dir_name] = file_list
    return dir_labels
    
def _int64_feature(value):
    """Wrap a single integer as a tf.train.Feature holding an int64_list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def _bytes_feature(value):
    """Wrap a single bytes value as a tf.train.Feature holding a bytes_list."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _get_dataset_filename(dataset_dir, split_name, shard_id):
    output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)
    
def creat_tfcord(INPUT_DATA, OUT_DATA):
    """Write one TFRecord file per class directory found under INPUT_DATA.

    Every image is resized to 224x224, converted to grayscale ('L' mode) and
    serialized as raw bytes together with an integer class label. The shard
    id is fixed at 1, so each class produces exactly one file in OUT_DATA.
    """
    dir_labels = get_labels_dirs(INPUT_DATA)
    # Assign each class name a stable integer id based on dict order.
    class_names_to_ids = dict(zip(dir_labels.keys(), range(len(dir_labels.keys()))))
    for class_name, image_paths in dir_labels.items():
        class_id = class_names_to_ids[class_name]
        shard_id = 1
        out_path = _get_dataset_filename(OUT_DATA, class_id, shard_id)
        writer = tf.python_io.TFRecordWriter(out_path)
        for image_path in image_paths:
            img = Image.open(image_path)
            # Resize to the model's expected input size, then grayscale,
            # then dump the pixel buffer as raw bytes.
            img = img.resize((224, 224))
            raw_pixels = np.array(img.convert('L')).tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(class_id),
                'image_raw': _bytes_feature(raw_pixels)
            }))
            writer.write(example.SerializeToString())
        writer.close()
        print ("TFRecord文件已保存。")


def read_tfrecord(READ_DATA, batch_size):
    """Build a shuffled-batch input pipeline over the TFRecord shards in READ_DATA.

    Returns a pair (example_batch, label_batch): examples are flat float32
    vectors of length 224*224, labels are int32 scalars. The caller must run
    the local AND global variable initializers (match_filenames_once creates
    a local variable) and start the queue runners before evaluating either
    tensor, otherwise shuffle_batch raises OutOfRangeError.
    """
    pattern = READ_DATA + "/image_*"
    shard_files = tf.train.match_filenames_once(pattern)
    filename_queue = tf.train.string_input_producer(shard_files, shuffle=True)
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    parsed = tf.parse_single_example(
        serialized,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        })
    # Raw grayscale bytes back to a flat float32 pixel vector.
    pixels = tf.cast(tf.decode_raw(parsed['image_raw'], tf.uint8), tf.float32)
    example = tf.reshape(pixels, [224 * 224])
    label = tf.cast(parsed['label'], tf.int32)
    min_after_dequeue = 3
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch
    
if __name__ == '__main__':
    # Input image root and TFRecord output/input directories.
    INPUT_DATA = 'flower_photos'
    OUT_DATA = 'flower_photos_tfrecord'
    READ_DATA = 'flower_photos_tfrecord'
    # Total shard count used by _get_dataset_filename when building file names.
    NUM_SHARDS = 2
    # Uncomment to (re)generate the TFRecord files before reading them.
    #    creat_tfcord(INPUT_DATA,OUT_DATA)
    # IMPORTANT: read_tfrecord must be called BEFORE the initializers below.
    # tf.train.match_filenames_once (inside read_tfrecord) creates a local
    # variable; if it does not exist when local_variables_initializer runs,
    # the shuffle queue never fills and sess.run raises OutOfRangeError
    # ("RandomShuffleQueue ... insufficient elements").
    example_batch1, label_batch1 =   read_tfrecord(READ_DATA, batch_size=100)
    sess = tf.Session()
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    # Start the queue-runner threads that feed the filename/example queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    for i in range(2):
        images, labels= sess.run([example_batch1, label_batch1])
        print(images.shape, labels)

    # Stop and join the background threads cleanly.
    coord.request_stop()
    coord.join(threads)

你可能感兴趣的:(tensorflow的tfrecord创建与读取)