Tensorflow笔记(七)

数据读取

  • TFRecord格式
    • TFRecord格式介绍
    • TFRecord格式转换示例代码
    • 编码阶段
    • 解码阶段
  • 队列
    • 数据队列
    • 文件队列
    • 使用多线程处理输入的数据
  • 组织数据batch

TFRecord格式

TFRecord格式介绍

TFRecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。

TFRecord格式转换示例代码

编码阶段

import tensorflow as tf
import numpy as np
from PIL import Image

data_dir = "/home/dataset/train/"
save_dir = "/home/tfrecord/train.tfrecords"

def  _int64_feature(value):
	return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def  _bytes_feature(value):
	return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

writer = tf.python_io.TFRecordWriter(save_dir)

for img in os.listdir(data_dir):
	img_path = data_dir + img
	img_raw = Image.open(img_path)
	img_raw = img_raw.resize((28,28))
	img_raw = img_raw.tostring()
	
	# tf.train.Example来定义我们要填入的数据格式,然后使用tf.python_io.TFRecordWriter来写入
	feature = {
		"label": _int64_feature(0),
		"img_raw": _bytes_feature(img_raw)
	}
	
	features = tf.train.Features(feature=feature)
	
	# 定义一个Example,将相关信息写入到这个数据结构
	example = tf.train.Example(features=features)
	
	# 将一个Example写入到TFRecord文件
    writer.write(example.SerializeToString()) 
writer.close()     

解码阶段

def read_and_decode(save_dir):
	filename_queue = tf.train.string_input_producer([save_dir])
	
	# 创建一个TFRecord类的实例
	reader = tf.TFRecordReader()
	
	_, serialized_example = reader.read(filename_queue)
	
	# 解析读入的一个record
    features = tf.parse_single_example(
        serialized_example,
        features={
	        "label": tf.FixedLenFeature([], tf.int64), 
	        "img_raw": tf.FixedLenFeature([], tf.string)
        })
        
    # decode_raw()用于将字符串解析成图像对应的像素数组
    img = tf.decode_raw(features["img_raw"], tf.uint8) 
    img = tf.reshape(img,[28,28,1])
    img = tf.cast(img, tf.float32) 
    label = tf.cast(features["label"], tf.int32)
    return img, label

在解析features时,我们使用了FixedLenFeature类,这个类会将解析的结果转换为一个Tensor。
由于图像预处理会拖慢整个训练过程,为了使得对更多、更大的图像进行预处理不会成为神经网络模型训练速度的瓶颈,Tensorflow提供了队列加多线程处理输入数据的方法。

队列

数据队列

  • FIFIQueue:先进先出队列;
  • RandomShuffleQueue:会打乱队列中元素的顺序,每次出队列的元素都是随机选择的。在训练神经网络时可能有随机抽取训练数据的需求,因此RandomShuffleQueue满足了这个需求。

修改队列状态的函数:enqueue_many()初始化队列、dequeue()出队、enqueue()入队。

import tensorflow as tf

// 元素数和数据类型
Queue = tf.FIFOQueue(2,"int32")
//队列初始化
queue_init = Queue.enqueue_many(([10,100]))
//出队
a = Queue.dequeue()
b = a + 10
//入队
Queue_en = Queue.enqueue([b])

with tf.Session() as sess:
	queue_init.run()
	for i in range(5):
		//都是op:a、Queue_en
		v,_ = sess.run([a,Queue_en]) 
		//打印队首元素
		print(v)

文件队列

将文件组织成队列需要先使用Tensorflow的train.match_filenames_once()函数来获取符合一个正则表达式的所有文件。这个函数会返回一个文件列表,然后将得到的文件列表送入train.string_input_producer()函数,这个函数会使用初始化时提供的文件列表创建一个输入文件队列,这个文件列表也可以不使用train.match_filenames_onece()函数返回的值而是直接给函数的参数赋予一个文件列表。输入文件队列中的元素就是文件列表中的所有文件,这个队列可作为文件读取函数(如TFRecordReader类的read()函数)的参数。
每次调用文件读取函数时,train.string_input_producer()函数会先判断当前是否已有打开的文件可读,如果没有已打开的文件或者打开的文件已经读完,这个函数会从输入队列中出队一个文件并从这个文件开始读取数据。

import tensorflow as tf

//获取符合正则表达式的所有文件形成一个文件列表
files = tf.train.match_filenames_once("./data_tfrecords-*")
//shuffle=True表示随机出队队列中的一个元素
file_queue = tf.train.string_input_producer(files,shuffle=False)
reader = tf.TFRecordReader()
_,serialized_example = reader.read(file_queue)

使用多线程处理输入的数据

Tensorflow的Session()对象是支持多线程的,因此多个线程可以很方便的在同一个会话下对同一个队列并行的执行操作。Tensorflow提供了两个类帮助实现多线程:

  • train.Coordinator类:偏重于管理线程
  • train.QueueRunner类:创建线程。

在train.add_queue_runner()函数中如果没有指定自己的集合,那么这些线程会被加入到计算图默认的GraphKeys.QUEUE_RUNNERS集合。train.start_queue_runners()函数会默认启动GraphKeys.QUEUE_RUNNERS集合中所有的QueueRunner。

import tensorflow as tf

queue = tf.FIFOQueue(100,'float')
//每次入队10个随机数
enqueue = queue.enqueue([tf.random_normal([10])])

//使用QueueRunner创建10个线程进行队列的入队操作
qr = tf.train.QueueRunner(queue,[enqueue] * 10)
//加入计算图集合
tf.train.add_queue_runner(qr)

//定义出队操作
out_op = queue.dequeue()

with tf.Session() as sess:
	//使用Coordinator类来协同启动线程
	coordinator = tf.train.Coordinator()

	//启动所有线程
	threads = tf.train.start_queue_runners(sess=sess,coord=coordinator)

	for i in range(10):
		print(sess.run(out_op))
	coordinator.request_stop()
	coordinator.join(threads)

组织数据batch

Tensorflow提供了train.batch()函数和train.shuffle.batch()函数来将数据组织成batch。这两个函数都会生成一个队列,队列的入队操作是产生单个样例的方法,而每次出队得到的是一个batch的样例。

//batch(tensors,batch_size,num_threads,capacity)
//capacity:用于组合成batch的队列中最多可以缓存的样例个数,即队列的最大容量
image_batch,label_batch = tf.train.batch([images,labels],batch_size=batch_size,capacity=capacity)

//shuffle_batch(tensors,batch_size,capacity,min_after_dequeue,num_threads)
//min_after_dequeue:表示限制出队时队列中元素的最少个数。不足min_after_dequeue时,出队操作将等待更多的样例入队才会完成
image_batch,label_batch = tf.train.shuffle_batch([images,labels],batch_size=batch_size,capacity=capacity,100)

你可能感兴趣的:(Tensorflow)