【Tensorflow】TFRecord数据的写入和读取经验

问题描述:利用tensorflow进行神经网络训练,当数据集较小或者内存足够大时,通常的做法是将全部数据集加载到内存里,然后再将数据集分批feed给网络进行训练(一般配合yield使用效果更佳)。但是,当数据集大到内存不足以全部加载进来的时候,必须寻找新的加载数据的方法。
解决办法:
可以尝试使用tensorflow提供的队列queue,训练时从文件中分批读取数据。这里选择tensorflwo内定的标准格式TFRecord.

TFRecord简介

TFRecord是一种二进制文件,可以支持多线程数据读取,可以通过batch_size和epoch参数来控制训练时单次batch的大小和样本迭代次数,同时能更好地利用内存和方便数据的复制和移动,所以是tensorflow进行大规模深度学习训练的首选。

TFRecord文件的制作

每个训练样本在TFRecord中称为example,tensorflow使用tf.train.Example协议来存储训练样本,每个example本质上是一个字典dict类型,用来存储一个训练样本的多个feature信息(如input、label、mask等等),且每个feature信息必须是tensorflow预定义好的类型(ByteList,FloatList以及Int64List中的一种)。最后,example通过SerializeToString()方法将样例序列化成字符串存储,tensorflow通过TFRecordWriter将这些序列化之后的字符串存成tfrecord形式。

import tensorflow as tf
import glob  #glob为python自带,无须另外安装
import scipy.io as sio
import numpy as np

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def load_data(path):
    data = sio.loadmat(path)
    data = data['Img']
    if data.shape[0] != 256:
        data = data[8:264]
    data_real = np.real(data)
    data_imag = np.imag(data)
    data = np.concatenate((data_real,data_imag),axis=-1)
    return data

folder = 'D:\Data\MRbrain\Data*' 
folders = glob.glob(folder_path)  # 表示MRbrain目录下所有Data开头的文件夹绝对路径,list列表类型返回
writer = tf.python_io.TFRecordWriter('train.tfrecords')
for i in range(len(folders)):
    data_path = folders[i] + '\\*.mat' # 注意windows和linux系统下反斜线的差异
    data_addr = glob.glob(data_path) # 表示当前Data文件夹下所有.mat格式的文件,list列表类型返回
    for j in range(len(data_addr)):
        img = load_data(data_addr[j]) # load当前的.mat数据,如果是图片格式,修改对应的load_data函数即可
        feature = {'train/label': _bytes_feature(img.tobytes())} # 这里可以在dict中添加多个key-value对,不同的数据序列化为相应的类型,如图片转换为bytes,宽高、类别等信息可转换为int64或float类型,视具体情况而定
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
    print('Data{}/{} Done'.format(i,len(folders)))
writer.close()

TFRecord文件的读取

首先要解析数据

def read_and_decode(filename, batch_size):
    filename_queue = tf.train.string_input_producer([filename]) # 创建文件队列
    reader = tf.TFRecordReader()
    _,serialized_example = reader.read(filename_queue) #从队列中读取相应的example
    feature = {'train/label':tf.FixedLenFeature([],tf.string)}
    features = tf.parse_single_example(serialized_example, features=feature) # 从example中按照存储时的编码格式进行解析
    img = tf.decode_raw(features['train/label'], tf.float64) # 从解析结果中读取存储的信息,并进行类型转换,对图像数据需用decode_raw进行解码,然后进行reshape
    img = tf.reshape(img, shape=[256, 256, 24])
    # 分批从队列中读取,shuffle_batch会进行打乱,可以选择多个线程同时读取队列数据,capacity表示队列容量,min_after_dequeue表示读取一次后队列至少需要剩下的样本数
    img_batch = tf.train.shuffle_batch([img], batch_size=batch_size, num_threads=64, capacity=30, min_after_dequeue=10)
    return img_batch

解析TFRecord文件的函数定义好后,开始启用队列读取数据

import tensorflow as tf
import numpy as np

filename = 'train.tfrecords'
img_batch = read_and_decode(filename,4)
with tf.Session() as sess:
    coord = tf.train.Coordinator() # coord用来协调各个线程
    threads = tf.train.start_queue_runners(coord=coord) # 开启各个队列的线程
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    i = 0
    while not coord.should_stop():
        image = sess.run(img_batch) # 返回的image即为一个batch的数据,这里其维度为[4,256,256,24]
        print('{}th batch load.'.format(i))
        i += 1
    coord.request_stop()
    coord.join(threads)

使用队列的方法读取TFRecord数据的优势在于,既解决了数据量过大以致内存不足的问题,又能启用多线程分批读取数据,同时读取数据和进行前向反向计算,一定程度上保证了效率。

你可能感兴趣的:(tensorflow)