制作自己的tfrecords数据集

最近想要制作自己的图像数据集,看了很多网上的教程,但是坑太多。。。。现在为了让大家少踩坑,附上全部代码,其中包括tfrecords格式数据集的制作与读取。

 

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf 
from PIL import Image
#制作数据自己的tfrecords数据集
cwd = os.getcwd()#这里的0,1,2代表的是数据存储的类别文件夹
classes={'0','1','2'} 
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes): 
class_path = cwd + "/" + name + "/" 
for img_name in os.listdir(class_path): 
img_path = class_path + img_name 
img = Image.open(img_path) 
img = img.resize((32, 32)) 
img_raw = img.tobytes() 
example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) writer.write(example.SerializeToString()) 
writer.close()
#显示tfrecords格式中的图片
filename_queue = tf.train.string_input_producer(["train.tfrecords"]) 
reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue) 
features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) 
image = tf.decode_raw(features['img_raw'], tf.uint8)image = tf.reshape(image, [32, 32, 3])
label = tf.cast(features['label'], tf.int32)with tf.Session() as sess:
 #init_op = tf.initialize_all_variables() 
init_op = tf.global_variables_initializer() 
sess.run(init_op) 
coord=tf.train.Coordinator() 
threads= tf.train.start_queue_runners(coord=coord) 
for i in range(17): ##这里的17要特别注意,17在这里表示的是我的数据集一共有17张图像 
example, l = sess.run([image,label])
img=Image.fromarray(example, 'RGB') 
img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg') print(example, l) 
coord.request_stop() 
coord.join(threads)

 

 

 

 

 

 

 


 

你可能感兴趣的:(DL)