【Tensorflow学习笔记---制作并读取TFRecord数据集】

【Tensorflow学习笔记---制作并读取TFRecord数据集】_第1张图片

参考链接:https://blog.csdn.net/m0_37407756/article/details/80684883

代码:

import os
import tensorflow as tf
from PIL import Image  
import matplotlib.pyplot as plt
import numpy as np
 
cwd='C:\\Users\\Administrator\\Desktop\\0430\\'
classes={'1','2'} #
class_map = {}    # 文件名与label关系,保存便于查看
writer= tf.python_io.TFRecordWriter("train.tfrecords") 




#制作部分
for index, name in enumerate(classes):
    class_path = cwd + name + '\\'
    class_map[index] =  name
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name  # 每一个图片的地址
        img = Image.open(img_path)
        img = img.resize((128, 128))
        img_raw = img.tobytes()  # 将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))  # example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  # 序列化为字符串

writer.close()
 
 
txtfile = open('class_map.txt','w+')
for key in class_map.keys():
    txtfile.writelines(str(key)+":"+class_map[key]+"\n")
txtfile.close()


#读取部分
filename_queue = tf.train.string_input_producer(["train.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [128, 128, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord=tf.train.Coordinator() #创建一个协调器,管理线程
    threads= tf.train.start_queue_runners(coord=coord)#启动QueueRunner, 此时文件名队列已经进队
    for i in range(20):
        example, l = sess.run([image,label])#在会话中取出image和label
        img=Image.fromarray(example, 'RGB')#这里Image是之前提到的
        img.save( cwd +str(i)+'_''Label_'+str(l)+'.jpg')#存下图片
        print(example, l)
    coord.request_stop()
    coord.join(threads)

 

你可能感兴趣的:(tensorflow)