TensorFlow(三)之多线程

本博文参考TensorFlow实战Google深度学习框架(郑泽宇,顾思宇),仅用作学习

一、TFRecord输入数据格式

TFRecord是tensorflow中存储数据的统一格式。可以统一不同的原始数据格式,并更加有效地管理不同的属性。TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。是一种可将图像的数据和标签放在一起的二进制文件,能节省内存,在TensorFlow中快速读取存储。

tf.train.Example的定义如下:

message Example{
    Features Features=1;
};

message Features{
    map feature=1;
};

message Feature{
    oneof kind{
       ByteList bytes_list=1;
       FloatList float_list=2;
       Int64List int64_list=3;
    }
};

tf.train.Example中包含了一个从属性名称到取值的字典。属性名称为字符串,属性的取值可以为字符串(ByteList),实数列表(FloatList)或整数列表(Int64List)。

从文件中读取数据一般分为:把样本数据写入TFRecords二进制文件,再从队列中读取。

1、生成TFRecord文件

需要将数据填到tf.train.Example的协议缓存区(Protocol Buffer)中,将协议缓存区序列化为一个字符串,通过tf.python_io.TFRecordWriter写入TFRecord文件中。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

mnist=input_data.read_data_sets("/path/to/mnist/data",dtype=tf.unit8,one_hot=True)
#训练数据
image=mnist.train.images
#训练数据所对应的的正确答案,可以作为一个属性保存在TFRecorde中
labels=mnist.train.labels
#训练数据的图像分辨率,可以作为Example中的一个属性
pixels=image.shape[1]
num_examples=mnist.train.num_examples

#生成整数型的属性
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#生成字符串型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

filename="/path/to/outpput.tfrecords"
writer=tf.python_io.TFRecordWriter(filename)
for i in range(num_examples):
    #将图像矩阵转化为一个字符串
    image_raw=image[i].tostring()
    #将一个样例转化为ExamplecProtocol Buffer,并将所有信息写入这个数据结构
    example=tf.train.Example(features=tf.train.Features(feature={
        'pixels': _int64_feature(pixels),
        'labels': _int64_feature(np.argmax(labels[i])),
        'image_raw':_bytes_feature(image_raw)
    }))
    writer.write(example.SerializeToString())
writer.close()
2、从队列中读取

首先创建张量,从二进制文件中读取一个样本

创建张量,从二进制文件中随机读取一个mini-batch

把每一批张量传入网络作为输入节点

import tensorflow as tf

# 读取文件。
files = tf.train.match_filenames_once("/path/to/output.tfrecords")
filename_queue = tf.train.string_input_producer(files, shuffle=False)
reader = tf.TFRecordReader()
#从文件中读出一个样例(也可以永别的函数读取多个样例)
_,serialized_example = reader.read(filename_queue)

# 解析读取的样例(也可以用别的函数解析多个样例)
features = tf.parse_single_example(
      serialized_example,
      features={
        'image_raw':FixedLenFeature([],tf.string),
        'labels': FixedLenFeature([],tf.int64),
        'pixels': FixedLenFeature([],tf.int64)
      })

images=tf.decode_raw(features['image_raw'],tf.uint8)
labels=tf.cast(features['labels'],tf.int32)
pixels=tf.cast(features['pixels'],tf.int32)

sess=tf.Session()
#启动多线程处理数据
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)

for i in range(10):
    image,label,pixel=sess.run([images,labels,pixels])

二、队列

队列也是图中的一个节点

队列主要有FIFOQueue和RandomShuffleQueue。tensorflow可以利用队列来实现多线程输入数据处理。

FIFOQueue创建一个先入先出队列。RandomShuffleQueue创建一个随机队列,在异步训练中很重要

API:tf.FIFOQueue   tf.RandomShuffleQueue

import tensorflow as tf
#创建先进先出队列,指定队列最多可以保存两个元素
q=tf.FIFOQueue(2,"int32")
#初始化队列的元素
init=q.enqueue_many(([0,10],))
#将队列的第一个元素出列,保存在x里面
x=q.dequeue()
y=x+1
#将y加入队列
q_inc=q.enqueue([y])

with tf.Session() as sess:
    #运行初始化队列的操作
    init.run()
    #执行出队列,将队列的元素加1,加1后的元素进队的循环操作5次
    for _ in range(5):
       v,_=sess.run([x,q_inc])
       print(v)

三、队列管理器&线程和协调器

tensorflow提供了tf.Cooradinator和tf.QueueRunner两个类来完成多线程协同的功能。

QueueRunner:队列管理器

coordinator:协调器,协调线程间的关系可以视为一种信号量,用来做同步。主要用于协同多个线程一起停止,并提供了should_stop、request_stop和join三个函数

启动线程之前,需要先声明一个tf.Cooradinator类,并将这个类传入每一个创建的线程中。

启动的线程需要一直查询tf.Cooradinator类中提供的should_stop函数,返回值为True时,线程也退出。

每个线程都可以通过request_stop函数来通知其他线程退出。当线程调用request_stop时,should_stop的返回值将被设置为True,这样其他的线程就可以同时终止了。

import tensorflow as tf
import numpy as np
import threading
import time

#线程中运行的程序,每隔一秒判断是否需要停止并打印自己的ID
def MyLoop(coord,worker_id):
    #使用tf.Coordinator类提供的协同工具判断当前线程是否需要停止
    while not coord.should_stop():
        #随机停止所有的线程
        if np.random.rand()<0.1:
            print("stoping from id: %d\n" % worker_id)
            #调用coord.request_stop()函数来通知其他线程停止
            coord.request_stop()
        else:
            #打印当前线程ID
            print("working on id: %d\n" % worker_id)
        #暂停1秒
        time.sleep(1)
#声明一个协调器来协同多个线程
coord=tf.train.Coordinator()

#创建5个线程
threads=[
    threading.Thread(target=MyLoop,args=(coord,i,)) for  i in range(5)]
#启动所有线程
for t in threads:
    t.start()
#join操作等待所有线程关闭,这一函数才能返回
coord.join(threads)

四、输入数据处理框架

                                        TensorFlow(三)之多线程_第1张图片
import tensorflow as tf
files=tf.train.match_filenames_once("/path/to/output.tfrecords")
filename_queue = tf.train.string_input_producer(files, shuffle=False)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
      serialized_example,
      features={
          'image': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
          'height': tf.FixedLenFeature([], tf.int64),
          'width': tf.FixedLenFeature([], tf.int64),
          'channels': tf.FixedLenFeature([], tf.int64),
      })
image,label=features['image'],features['label']
height,width=features['height'],features['width']
channels=features['channels']

decoded_image=tf.decode_raw(image,tf.uint8)
decoded_image.set_shape([height,width,channels])

image_size=299
distorted_image=preprocess_for_train(decoded_image,image_size,image_size,None)

min_after_dequeue=10000
batch_size=100
capacity=min_after_dequeue+3*batch_size
image_batch,label_batch=tf.train.shuffle_batch(
    [distorted_image,label],batch_size=batch_size,
    capacity=capacity,min_after_dequeue=min_after_dequeue)

logit=inference(image_batch)
loss=calc_loss(logit,label_batch)
train_step=tf.train.GradientDescentOptimizer(learning_rate)\
    .minimize(loss)

with tf.Session() as sess:
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(TRAINING_ROUDNS):
            sess.run(train_step)
        coord.request_stop()
        coord.join(threads)


你可能感兴趣的:(深度学习)