对flower_photos的图片建立tfrecord文件
这里的代码不需要调试了,只要把flower_photos解压到与程序同一目录即可,当然还要新建一个flower_photos_tfrecord文件夹。创建开始。。。。
INPUT_DATA = 'flower_photos'
OUT_DATA = 'flower_photos_tfrecord'
READ_DATA = 'flower_photos_tfrecord'
NUM_SHARDS = 2
creat_tfcord(INPUT_DATA,OUT_DATA)#就是它了
所有的代码在下面:
因为需要把所有的图片统一到同一尺寸,用了这个image_data.resize((224, 224))
接下来就是读取了
这里有个大坑啊!!!!!!!!
example_batch1, label_batch1 = read_tfrecord(READ_DATA, batch_size=100)
sess = tf.Session()
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
就是上面:
example_batch1, label_batch1 = read_tfrecord(READ_DATA, batch_size=100)
这句必须在两个初始化函数前面,而且初始化函数必须写,不然.......
OutOfRangeError: RandomShuffleQueue '_72_shuffle_batch_6/random_shuffle_queue' is closed and has insufficient elements (requested 100, current size 0)
这个大坑我可是一行一行调出来的。。。。
贴上所有的代码:
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 30 14:29:58 2018
@author: yanghe
"""
import tensorflow as tf
import os
import glob
import threading
from PIL import Image
import numpy as np
def get_labels_dirs(INPUT_DATA):
    """Return ``{class_dir_name: [image file paths]}`` for every non-empty
    class sub-directory directly under ``INPUT_DATA``.

    The root directory itself is skipped. Only jpg/jpeg files (either case)
    are collected; the list is de-duplicated and sorted so the result is
    deterministic even on case-insensitive filesystems where the 'jpg' and
    'JPG' globs match the same files twice.
    """
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    is_root_dir = True
    dir_labels = {}
    extensions = ('jpg', 'jpeg', 'JPG', 'JPEG')
    for sub_dir in sub_dirs:
        # os.walk yields INPUT_DATA itself first -- skip it.
        if is_root_dir:
            is_root_dir = False
            continue
        dir_name = os.path.basename(sub_dir)
        file_list = []
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        # BUG FIX: the original ran "if not file_list: continue" AFTER the
        # dict assignment (and as the loop's last statement), so it was dead
        # code and empty directories still ended up in the result. Skip them
        # before assigning.
        if not file_list:
            continue
        # set() drops duplicates from case-insensitive glob matches.
        dir_labels[dir_name] = sorted(set(file_list))
    return dir_labels
def _int64_feature(value):
    """Wrap a scalar integer as a tf.train.Feature holding an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Wrap a bytes value as a tf.train.Feature holding a BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, NUM_SHARDS)
return os.path.join(dataset_dir, output_filename)
def creat_tfcord(INPUT_DATA, OUT_DATA):
    """Convert every image under INPUT_DATA into TFRecord files in OUT_DATA.

    Each class sub-directory gets one shard file; every image is resized to
    224x224, converted to single-channel grayscale, and stored as raw bytes
    together with an integer class label.
    """
    dir_labels = get_labels_dirs(INPUT_DATA)
    # Map each class directory name to a stable integer id used as the label.
    class_names_to_ids = dict(zip(dir_labels.keys(), range(len(dir_labels))))
    for dir_label in dir_labels:
        split_name = class_names_to_ids[dir_label]
        # NOTE(review): shard_id is fixed at 1 even though NUM_SHARDS is 2,
        # so only a single shard per class is ever written -- confirm intent.
        shard_id = 1
        filename = _get_dataset_filename(OUT_DATA, split_name, shard_id)
        # Context managers close the writer and every image handle even when
        # an image fails to load (the original leaked both).
        with tf.python_io.TFRecordWriter(filename) as writer:
            for file in dir_labels[dir_label]:
                with Image.open(file) as img:
                    # Resize to the model's expected input resolution.
                    resized = img.resize((224, 224))
                    # Grayscale ('L' mode) -> uint8 array -> raw bytes.
                    image_data = np.array(resized.convert('L')).tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    'label': _int64_feature(split_name),
                    'image_raw': _bytes_feature(image_data)
                }))
                writer.write(example.SerializeToString())
        print("TFRecord文件已保存。")
def read_tfrecord(READ_DATA, batch_size):
    """Build a TF input pipeline that yields shuffled batches from the
    TFRecord shards under READ_DATA.

    Returns a pair of tensors ``(example_batch, label_batch)`` where each
    example is a flat float32 vector of 224*224 grayscale pixels and each
    label is an int32 class id. Note: this only constructs graph ops --
    queue runners must be started before evaluating the tensors.
    """
    pattern = READ_DATA + "/image_*"
    # match_filenames_once creates a LOCAL variable, so callers must run
    # tf.local_variables_initializer() after building this pipeline.
    shard_files = tf.train.match_filenames_once(pattern)
    filename_queue = tf.train.string_input_producer(shard_files, shuffle=True)

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        })

    # Raw bytes back to uint8 pixels, then to float32 for the model.
    pixels = tf.decode_raw(features['image_raw'], tf.uint8)
    pixels = tf.cast(pixels, tf.float32)
    label = tf.cast(features['label'], tf.int32)
    flat_image = tf.reshape(pixels, [224 * 224])

    min_after_dequeue = 3
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [flat_image, label],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch
if __name__ == '__main__':
    INPUT_DATA = 'flower_photos'
    OUT_DATA = 'flower_photos_tfrecord'
    READ_DATA = 'flower_photos_tfrecord'
    NUM_SHARDS = 2
    # creat_tfcord(INPUT_DATA, OUT_DATA)

    # IMPORTANT: build the whole input pipeline BEFORE running the
    # initializers. match_filenames_once creates a local variable; running
    # the initializers first leaves the filename queue empty and triggers
    # "OutOfRangeError: RandomShuffleQueue ... insufficient elements".
    example_batch1, label_batch1 = read_tfrecord(READ_DATA, batch_size=100)

    # 'with' closes the session on exit (the original leaked it).
    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # Queue runners feed the filename/example queues in the background.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(2):
            images, labels = sess.run([example_batch1, label_batch1])
            print(images.shape, labels)
        coord.request_stop()
        coord.join(threads)