Tensorflow读取大数据集的方法,tf.train.string_input_producer()和tf.train.slice_input_producer()

Tensorflow一共提供了3种读取数据的方法:
供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据,比如说用PIL和numpy处理数据然后喂入神经网络。
从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据,这就是这篇文将要讲的内容。
预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
对于大的数据集很难用numpy数组保存,所以这里介绍一下Tensorflow读取很大数据集的方法:string_input_producer()和slice_input_producer()。

他们两者区别可以简单理解为:string_input_producer每次放出一个文件名。slice_input_producer可以既可以同时放出文件名和它对应的label,也可以只放出一个文件名。而在实际应用代码的时候也只是读取文件的方式不一样,其他大致相同。

string_input_producer加载图片的reader是reader = tf.WholeFileReader() key,value = reader.read(path_queue)其中key是文件名,value是byte类型的文件流二进制,一般需要解码(decode)一下才能变成数组,然后进行reshape操作。

slice_input_producer加载图片的reader使用tf.read_file(filename)直接读取。记得图片需要解码和resize成数组,才可以放入内存队列file_queue中等待调用。
先放几个很棒的相关讲解:
csdn—tensorflow中协调器 tf.train.Coordinator 和入队线程启动器tf.train.start_queue_runners
csdn—Tensorflow两种数据读取方法应用、对比及注意事项
极客学院—三种数据读取
知乎—十图详解tensorflow数据读取机制(附代码)
很棒的项目:
Github—Chinese-Character-Recognition
知乎—TensorFlow与中文手写汉字识别----想飞的石头
首先使用string_input_producer()来读取三张如下的图片,并转化成数组的格式。
Tensorflow读取大数据集的方法,tf.train.string_input_producer()和tf.train.slice_input_producer()_第1张图片
Tensorflow读取大数据集的方法,tf.train.string_input_producer()和tf.train.slice_input_producer()_第2张图片
Tensorflow读取大数据集的方法,tf.train.string_input_producer()和tf.train.slice_input_producer()_第3张图片
代码如下:

'''
created on January 5 12:21 2018

@author:lhy
'''
import tensorflow as tf
path_list=['A.png','B.png','C.png']
img_path=tf.convert_to_tensor(path_list,dtype=tf.string)#将list转化张量tensor

image=tf.train.string_input_producer(img_path,num_epochs=1)#放入文件名队列中,epoch是1

def load_img(path_queue):
    #创建一个队列读取器,然后解码成数组,与slice的不同之处,重要!!!!!!!!!
    reader=tf.WholeFileReader()
    key,value=reader.read(path_queue)

    img=tf.image.convert_image_dtype(tf.image.decode_png(value,channels=3),tf.float32)#将图片decode成3通道的数组
    img=tf.image.resize_images(img,size=(224,224))
    return img

img=load_img(image)
print(img.shape)
#可以看出string进行处理的时候只处理了图片本身,对标签并没有处理。将图片放入内存队列,因为abtch_size=1,所以一次放入一张供读取。但是系统还是“停滞”状态。
image_batch=tf.train.batch([img],batch_size=1)

with tf.Session() as sess:
    tf.local_variables_initializer().run()
    tf.global_variables_initializer().run()
    coord=tf.train.Coordinator()
    #tf.train.start_queue_runners()函数才会启动填充队列的线程,系统不再“停滞”,此后计算单元就可以拿到数据并进行计算
    thread=tf.train.start_queue_runners(sess=sess,coord=coord)
    try:
        while not coord.should_stop():
            imgs=sess.run(image_batch)
            print(imgs.shape)
    #当文件队列读到末尾的时候,抛出异常
    except tf.errors.OutOfRangeError:
        print('done')
    finally:
        coord.request_stop()#将读取文件的线程关闭
    coord.join(thread)#将读取文件的线程加入到主线程中(虽然说已经关闭过)

运行结果如下:

(224, 224, 3)
(1, 224, 224, 3)
(1, 224, 224, 3)
(1, 224, 224, 3)
done

可以看到,三张图片被依次读了出来。
接下来使用slice_input_producer()来试试:

'''
created on January 5 13:08 2018

@author:lhy
'''
import tensorflow as tf

path_list=['A.png','B.png','C.png']
#加入了标签,在使用的时候可以直接对应标签取出数据
label=[0,1,2]
#转换成张量tensor类型
img_path=tf.convert_to_tensor(path_list,dtype=tf.string)
label=tf.convert_to_tensor(label,dtype=tf.int32)

#返回了一个包含路径和标签的列表,并将文件名和对应的标签放入文件名对列中,等待系统调用
image=tf.train.slice_input_producer([img_path,label],shuffle=True,num_epochs=1)#shuffle=Flase表示不打乱,当为True的时候打乱顺序放入文件名队列
labels=image[1]

def load_image(path_queue):
    #读取文件,这点与string_input_producer不一样!!!!!
    file_contents=tf.read_file(image[0])
    img=tf.image.convert_image_dtype(tf.image.decode_png(file_contents,channels=3),tf.float32)

    img=tf.image.resize_images(img,size=(228,228))
    return img

img=load_image(image)
print(img.shape)
#设置one_hot编码,并将labels规定为3种,在前向传播的时候默认会将结果的shape变为batch_size*3,从而达到分类的情况,这一步在使用标签的时候很重要
labels=tf.one_hot(labels,3)
img_batch,label_batch=tf.train.batch([img,labels],batch_size=1)

with tf.Session() as sess:
    #initializer for num_epochs
    tf.local_variables_initializer().run()
    coord=tf.train.Coordinator()
    thread=tf.train.start_queue_runners(sess=sess,coord=coord)
    try:
        while not coord.should_stop():
            imgs,label=sess.run([img_batch,label_batch])
            print(imgs.shape)
            print(label)
    except tf.errors.OutOfRangeError:
        print('Done')
    finally:
        coord.request_stop()
    coord.join(thread)


运行结果:

(228, 228, 3)
(1, 228, 228, 3)
[[0. 1. 0.]]
(1, 228, 228, 3)
[[1. 0. 0.]]
(1, 228, 228, 3)
[[0. 0. 1.]]
Done

可以看出,每次它取出了一个图片与它对应的标签,因为shuffle=True,所以随机取出的,当设置shuffle=False的时候会按照顺序取出,这种读取方法十分适合百万数据量级别的图片数据集。

你可能感兴趣的:(Tensorflow读取大数据集的方法,tf.train.string_input_producer()和tf.train.slice_input_producer())