TensorFlow——将自己的图片数据转换为TFRecord

TFRecord是TensorFlow提供的一种存储数据的格式,可方便的存储数据的各种信息。

下面程序以猫狗图片为例

1. 写入数据

从cats和dogs文件夹中读取图片,resize为特定大小,然后存入TFRecord文件中。

import os 
import tensorflow as tf 
from PIL import Image

curr_path='./path/'
classes={'cats','dogs'} 
writer= tf.python_io.TFRecordWriter("path/cats_dogs.tfrecords")#输出TFRecord的路径

for index,name in enumerate(classes):#enumerate(): 将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标
    class_path=curr_path+name+'/'
    for img_name in os.listdir(class_path): #os.listdir: 返回指定的文件夹包含的文件或文件夹的名字的列表。这个列表以字母顺序
        img_path=class_path+img_name 
        img=Image.open(img_path)
        img=img.resize((200,200))
        img_raw=img.tobytes()#图像转换为Bytes

        #将一个样例转换为Example Protocol Buffer, 并将所有信息写入这个数据结构
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        })) 
        writer.write(example.SerializeToString())  #将一个example写入TFRecord文件

writer.close()

2. 读取数据

从cats_dogs.tfrecords中读取数据,解析为数组,再转换为Image存入指定文件夹

import tensorflow as tf 
from PIL import Image
import matplotlib.pyplot as plt

filename = './path/cats_dogs.tfrecords'

filename_queue = tf.train.string_input_producer([filename])# 创建一个队列来维护输入文件列表
reader = tf.TFRecordReader()#创建读取TFRecord文件的reader
_, serialized_example = reader.read(filename_queue)#读取样例

#属性解析:tf.FixedLenFeature返回一个Tensor
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })

images = tf.decode_raw(features['img_raw'], tf.uint8) #将字符串解析为对应的像素数组
images = tf.reshape(images, [200, 200, 3])  #reshape images
#images = tf.cast(images, tf.float32) * (1. / 255) - 0.5 #throw images tensor
labels = tf.cast(features['label'], tf.int32) #throw labels tensor


with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    for i in range(10):
        image, label = sess.run([images, labels])
        img=Image.fromarray(image, 'RGB')#将image数据从array转换为Image
        img.save('./path/image_test/'+str(i)+'_''label_'+str(label)+'.jpg')#save image
        plt.imshow(img)#显示解析出来的图片
        plt.show()
    
    coord.request_stop()
    coord.join(threads)

 

你可能感兴趣的:(tensorflow)