Saving and Loading TFRecords

Edited on 2019/02/23. Please credit the source when reposting.

I wrote this quite a while ago; the TF version was around 1.7. Please be forgiving of any mistakes.

TF has two execution modes: static graphs and dynamic graphs (eager mode). Of the TFRecord creation and reading code below, some parts require graph mode plus session.run(), some require eager mode, and some work with either.
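A minimal sketch of the difference, using toy data (in older 1.x releases the eager loop may need tf.contrib.eager.Iterator instead of iterating the dataset directly):

import tensorflow as tf
# tf.enable_eager_execution()  # uncomment at the very start of the program for eager mode

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).batch(2)

# graph mode: build an iterator op, then pull values with session.run()
batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print(sess.run(batch))  # [1 2]

# eager mode: the dataset is consumed as a plain Python iterable (no Session)
# for batch in dataset:
#     print(batch.numpy())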

Saving TFRecords

Take create_tfrecords.py as an example. First, some helper functions:

import os
import cv2
import tensorflow as tf
import threading
import math
import argparse

"""
# Format of the saved data
The purpose of this script is to create a set of .tfrecords files
from a folder of images and a folder of annotations.
Annotations: rect(x1\ty1\tx2\ty2)#imagepath#kpts(x1\ty1\tx2\ty2\t...x68\ty68\t)\tpose(y\tp\tr)
Images: *.jpg, full images (not pre-cropped face images)

Example of use:
python create_tfrecords.py \
    --image_dir=/home/gpu2/hdd/dan/WIDER/val/images/ \
    --annotations_dir=/home/gpu2/hdd/dan/WIDER/val/annotations/ \
    --output=data/train_shards/ \
    --num_shards=100
"""

def make_args():
    parser = argparse.ArgumentParser()
    # parser.add_argument('-i', '--image_dir', default='E:\\Dataset\\FDDB\\val\\images\\', type=str)
    # parser.add_argument('-a', '--annotations_dir',default='E:\\Dataset\\FDDB\\val\\annotations\\', type=str)
    parser.add_argument('-o', '--output', default='/home/yfy1127yfy/pyProject/create_tfrecords_kpts_poses/results', type=str)
    parser.add_argument('-s', '--num_shards', type=int, default=1)
    return parser.parse_args()

# Wrap value in a bytes feature; convert it to a list first if it is not one
def _bytes_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

# Wrap value in a float feature; convert it to a list first if it is not one
def _float_list_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

# Wrap value in an int64 feature; convert it to a list first if it is not one
def _int64_list_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
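
As a quick sanity check (a toy sketch with made-up values, not part of the original script), the three helpers above can be used to build and serialize a tf.train.Example directly:

# build a tiny Example from the helpers above and serialize it
feature = {
    'image': _bytes_feature(b'\x00' * 16),           # fake 4x4 uint8 image bytes
    'keypoints': _float_list_feature([0.1, 0.2]),    # fake normalized keypoints
    'pose': _float_list_feature([0.5, 0.5, 0.5]),    # fake normalized yaw/pitch/roll
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
serialized = example.SerializeToString()  # this byte string is what a TFRecordWriter writes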

# Split a big list (already read into memory) into n smaller lists, mainly for multithreading
def div_list(lst, n):
    num_eachlist = int(len(lst) / n)
    divlists = []
    # chunks 1 ~ n-1 each have len(lst)//n elements; the last chunk also takes the remainder
    # e.g. if len(lst) == 103 and n == 5, the chunks cover indexes 0~19, 20~39, 40~59, 60~79, 80~102
    for i in range(n - 1):
        divlists.append(lst[i * num_eachlist:(i + 1) * num_eachlist])
    divlists.append(lst[(n - 1) * num_eachlist:])
    return divlists
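
A toy illustration of how the split behaves (not in the original script):

# 7 items split into 3 chunks: the last chunk absorbs the remainder
print(div_list([0, 1, 2, 3, 4, 5, 6], 3))
# [[0, 1], [2, 3], [4, 5, 6]]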

Write the images and labels into TFRecords. This part of the code makes no graph/eager distinction. The file extension does not matter, as long as the file is produced by a TFRecord writer.

# With num_shards=n, this function writes the input lines into n tfrecord files
def create_tfrecords(output_dir, output_name, lines, thread_id = 0, resize_w = 128, resize_h = 128, num_kpts = 68, num_shards = 1):
    shard_id = 0  # if the list is large, it can be saved as several tfrecords; this is the shard index
    num_examples_written = 0  # counter of examples written to the current shard
    # file_annotation = open(annotation_path, 'r')
    # lines = file_annotation.readlines()
    shard_size = math.ceil(len(lines) / num_shards)

    for line_id, line in enumerate(lines):
        words = line.split('#')
        rect = words[0].strip().split('\t')
        rect = [int(r) for r in rect]
        image_path = words[1].strip()
        kpts = words[2].strip().split('\t')[:-3]
        pose = words[2].strip().split('\t')[-3:]
        feature = {}

        # read the image and crop the face box
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (it does not raise) when the file cannot be read
        if image is None:
            print(image_path + '   #read image failed!')
            continue
        source_h, source_w = image.shape
        if source_h == 0 or source_w == 0:
            print('image read error, w==0 or h==0!')
            continue
        box_h = rect[3] - rect[1]
        box_w = rect[2] - rect[0]
        if box_h == 0 or box_w == 0:
            print('face box error, boxw==0 or boxh==0!')
            continue
            # rect[0] = 1#x
            # rect[1] = 1#y
            # rect[2] = image.shape[1]-1#w-1
            # rect[3] = image.shape[0]-1#h-1
        face = image[rect[1]:rect[3], rect[0]:rect[2]]
        gray_face = cv2.resize(face, (resize_w, resize_h))
        # cv2.imshow('gray_face', gray_face)
        # cv2.waitKey(0)
        rate_h = float(resize_h) / float(box_h)
        rate_w = float(resize_w) / float(box_w)
        image_bytes = gray_face.tobytes()
        feature['image'] = _bytes_feature(image_bytes)
        # feature['facebox'] = _int64_list_feature(rect)

        # convert keypoints to coordinates normalized to the resized face crop
        kpts_float = []
        if len(kpts) != num_kpts*2:
            print('num of kpts error!')
            continue
        for i in range(num_kpts):
            kpt_x = int(kpts[2 * i])
            kpt_y = int(kpts[2 * i +1])
            kpts_float.append(float((kpt_x - rect[0]) * rate_w/resize_w))
            kpts_float.append(float((kpt_y - rect[1]) * rate_h/resize_h))
        feature['keypoints'] = _float_list_feature(kpts_float)

        # convert pose: normalize yaw/pitch/roll from [-90, 90] degrees to [0, 1]
        pose_float = []
        for p in pose:
            pose_float.append((float(p)+90)/180)
        feature['pose'] = _float_list_feature(pose_float)

        # convert feature to example
        tf_features = tf.train.Features(feature=feature)
        tf_example = tf.train.Example(features=tf_features)
        tf_serialized = tf_example.SerializeToString()

        # write
        # on the first example of a shard, open a new writer
        if num_examples_written == 0:
            shard_path = os.path.join(output_dir, output_name+'_T%02d_NO%02d.tfrecords' % (thread_id, shard_id))
            writer = tf.python_io.TFRecordWriter(shard_path)
        writer.write(tf_serialized)
        num_examples_written += 1
        # print a progress message every 50000 examples
        if num_examples_written % 50000 == 0:
            print('thread_id: %02d    writing the %d example, total %d examples'%(thread_id, shard_id*shard_size+num_examples_written,len(lines)))
        # if num_examples_written==2:
        #     break
        # shard_size is the number of examples per tfrecord; when a shard is full, close the writer so the next example opens a new file
        if num_examples_written == shard_size:
            shard_id += 1
            num_examples_written = 0
            writer.close()

    # close the writer of the last (possibly partial) shard, if it is still open
    if num_examples_written > 0:
        writer.close()
    print('thread_id: %02d   write over!' % (thread_id))

Use multiple threads: split the big list into n smaller lists and save them to tfrecords in parallel.

def main():
    ARGS = make_args()
    # annotations_dir = ARGS.annotations_dir
    output_dir = ARGS.output
    num_shards = ARGS.num_shards
    # change  3 variables
    annotation_path = '/home/yfy1127yfy/Dataset/dataForKpt/fromjiakun/pose_0to30_result.txt'
    output_name = 'pose_0to30'
    num_shards = 8 # number of threads
    file_annotation = open(annotation_path, 'r')
    lines = file_annotation.readlines()
    # random.shuffle(lines)  # shuffle the whole list; not strictly necessary, since reading the tfrecords shuffles anyway
    # divide the list into num_shards lists
    divlists = div_list(lines,num_shards)
    # multi thread ,create tfrecords
    threads = []
    # with daemon=False the worker threads would keep running even after the main thread exits;
    # here daemon is set to True, and join() is used below to make sure they finish safely
    for i, line in enumerate(divlists):
        threads.append(threading.Thread(target=create_tfrecords, args=(output_dir,output_name,line,i), daemon=True))

    for t in threads:
        t.start()
    # make the main thread wait for the workers
    for t in threads:
        t.join()
    # create_tfrecords(output_dir, divlists)
    print('all work done!')


if __name__ == '__main__':
    main()

Reading images and labels from TFRecords

There are two approaches. The first uses tf.TFRecordReader(); it cannot be used in eager execution and, compared with the tf.data approach, it is less convenient and seems slightly less efficient. The second uses tf.data. Both approaches need to parse examples; the helper function is as follows:

import cv2
import tensorflow as tf
import os
import glob
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import batching

def from_parse_example(tf_record_serialized):
    features = {}
    # declare the expected dtype, key, and shape of each feature in advance
    features['image'] = tf.FixedLenFeature(shape=[],dtype=tf.string)
    # writing these two with shape=[] raised an error; explicitly giving the dimensions avoids it
    features['pose'] = tf.FixedLenFeature(shape=[3,1], dtype=tf.float32)
    features['keypoints'] = tf.FixedLenFeature(shape=[136,1], dtype=tf.float32)
    # parse the serialized example
    tf_record_features = tf.parse_single_example(serialized=tf_record_serialized, features=features)
    # decode the image from raw bytes to tf.uint8
    tf_record_image =tf.decode_raw(tf_record_features['image'], tf.uint8)
    # reshape to the image's proper h, w, c
    tf_record_image=tf.reshape(tensor=tf_record_image,shape=[128,128,1])
    tf_record_kpts=tf.cast(tf_record_features['keypoints'], dtype=tf.float32)
    tf_record_pose=tf.cast(tf_record_features['pose'],dtype=tf.float32)
    # image_batch, kpts_batch, pose_batch = tf.train.batch([tf_record_image, tf_record_kpts, tf_record_pose], batch_size=5, num_threads=1, capacity=5)
    return tf_record_image, tf_record_kpts, tf_record_pose

tf.TFRecordReader

This approach tends to hang when reading large amounts of data; I suspect it reads everything into memory in one go, so it is only suitable for small datasets.

############################## use TFRecordReader
def use_TFRecordReader(tfrecord_path):
    tf_record_filename_queue = tf.train.string_input_producer([tfrecord_path])
    tf_record_reader = tf.TFRecordReader()
    _, tf_record_serialized = tf_record_reader.read(tf_record_filename_queue)
    return from_parse_example(tf_record_serialized)


def main_use_TFRecordReader():
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = '9'
    config.gpu_options.allow_growth = True
    tfrecord_path = '/home/yfy1127yfy/pyProject/create_tfrecords_kpts_poses/results/shard-00000000.tfrecords'
    image, kpts, pose = use_TFRecordReader(tfrecord_path)
    with tf.Session(config=config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        for i in range(20):
            final_image,final_kpts,final_pose = sess.run([image, kpts, pose])
            for j in range(68):
                # keypoints are stored normalized to [0, 1], so scale back to the 128x128 crop before drawing
                cv2.circle(final_image, (int(128 * final_kpts[0 + 2 * j]), int(128 * final_kpts[1 + 2 * j])), 1, (255))
            print(final_kpts,final_pose)
            cv2.imshow('image', final_image)
            cv2.waitKey(0)
        coord.request_stop()
        coord.join(threads)
##############################

tf.data.TFRecordDataset (recommended)

The process (a minimal plain-tf.data sketch of these two steps follows, then the full iterator class used in this post):

1) Interleave reads from cycle tfrecord files (the parallel_interleave transformation); the reads are unordered. Apply shuffle and repeat and set the batch size, which yields a dataset.

2) Use map + batch to read and parse the records in the dataset, build the corresponding iterator from the result, and call iterator.get_next() to obtain the features.
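
For reference, a minimal sketch of these two steps using only the core tf.data API (no contrib helpers); the file pattern and parameter values are placeholders, and from_parse_example is the parser defined above:

def make_dataset_sketch(pattern='results/*.tfrecords', cycle_length=4, batch_size=32):
    files = tf.data.Dataset.list_files(pattern)
    # step 1: interleave records from cycle_length files at a time, then shuffle and repeat
    dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=cycle_length)
    dataset = dataset.shuffle(buffer_size=10000).repeat()
    # step 2: parse each serialized example, batch, and expose an iterator
    dataset = dataset.map(from_parse_example).batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()

The class below does the same thing with tf.contrib's parallel_interleave and map_and_batch, which parallelize the interleaving and the parsing.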

######################## use tf.data.TFRecordDataset
# input params:
#      file_names : glob pattern for all tfrecord files to read, e.g. 'xxx/*.tfrecords'
#      cycles     : number of files to read from concurrently
class tfrecord_dataset_iterator(object):
    def __init__(self, file_names, cycles, batch_sizes, reshape_w, reshape_h, channels, decode_func):
        self.file_names = glob.glob(file_names)
        #        self.file_names = (file_names)
        self.cycles = cycles
        self.batch_sizes = batch_sizes
        self.reshape_w = reshape_w
        self.reshape_h = reshape_h
        self.channels = channels
        self.decode_func = decode_func
        self.epochs = 1

    # image iterator: fetches image data from the tfrecord files
    def imgs_iterator(self):
        # a dataset that reads from cycle tfrecord files at the same time
        proto = tf.data.TFRecordDataset.list_files(self.file_names)
        # parallel_interleave interleaves records from cycle_length files at a time;
        # whether the records are shuffled, how many are taken from each file, etc. depends on the other arguments
        proto = proto.apply(interleave_ops.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=self.cycles, sloppy=True))
        # feels like a buffer to me; prefetch seems to speed things up without limiting the shuffle range
        proto = proto.prefetch(buffer_size=self.batch_sizes)
        # the basic steps: shuffle + repeat + batch
        proto = proto.shuffle(buffer_size=100000)
        proto = proto.repeat(count=self.epochs)  # todo

        # map + batch
        # according to the TensorFlow docs, map_and_batch is faster than map followed by batch
        proto = proto.apply(batching.map_and_batch(
            map_func=self.decode_func,
            batch_size=self.batch_sizes,
            num_parallel_batches=None))  # TODO:
        images_iterator = proto.make_one_shot_iterator()
        return images_iterator


def main_use_TFRecordDataset():
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = '9'
    config.gpu_options.allow_growth = True
    batch_size = 5
    # tfrecord_path = '/home/yfy1127yfy/pyProject/create_tfrecords_kpts_poses/results/shard-00000000.tfrecords'
    with tf.Session(config=config) as sess:
        iter_tfrecord = tfrecord_dataset_iterator(
            file_names='/home/yfy1127yfy/pyProject/create_tfrecords_kpts_poses/results/pose_60to90*.tfrecords'
            , cycles=1
            , batch_sizes=batch_size
            , reshape_w=128
            , reshape_h=128
            , channels=1
            , decode_func=from_parse_example)
        iter_image = iter_tfrecord.imgs_iterator()
        image, kpts, pose = iter_image.get_next()
        for i in range(30):
            list_image,list_kpts,list_pose = sess.run([image, kpts, pose])
            for batch_id in range(batch_size):
                final_image = list_image[batch_id]
                final_kpts = list_kpts[batch_id]
                final_pose = list_pose[batch_id]
                for j in range(68):
                    cv2.circle(final_image, (int(128*final_kpts[0 + 2 * j]), int(128*final_kpts[1 + 2 * j])), 1, (255))
                # pose was stored normalized to [0, 1]; convert back to degrees in [-90, 90]
                final_pose = [int(p * 180 - 90) for p in final_pose]
                print(final_pose)
                # cv2.imshow('image', final_image)
                # cv2.waitKey(0)
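
In training code, the batched tensors returned by get_next() are used directly as model inputs; each sess.run() of the training op pulls a fresh batch. A rough sketch with a deliberately tiny placeholder model (not from the original post):

def train_from_iterator_sketch(image, kpts, num_steps=100):
    # image and kpts are the tensors returned by iter_image.get_next() above
    flat = tf.layers.flatten(tf.cast(image, tf.float32) / 255.0)
    pred = tf.layers.dense(flat, 136)  # toy regressor for the 68 (x, y) keypoints
    loss = tf.losses.mean_squared_error(tf.reshape(kpts, [-1, 136]), pred)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(num_steps):
            _, loss_value = sess.run([train_op, loss])  # each run consumes one batch from the dataset
            if step % 10 == 0:
                print('step %d, loss %.4f' % (step, loss_value))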

Appendix: the official docs explanation:

tf.contrib.data.parallel_interleave

tf.contrib.data.parallel_interleave( map_func, cycle_length, block_length=1, sloppy=False, buffer_output_elements=None, prefetch_input_elements=None )

Defined in tensorflow/contrib/data/python/ops/interleave_ops.py.

A parallel version of the Dataset.interleave() transformation.

parallel_interleave() maps map_func across its input to produce nested datasets, and outputs their elements interleaved. Unlike tf.data.Dataset.interleave, it gets elements from cycle_length nested datasets in parallel, which increases the throughput, especially in the presence of stragglers. Furthermore, the sloppy argument can be used to improve performance, by relaxing the requirement that the outputs are produced in a deterministic order, and allowing the implementation to skip over nested datasets whose elements are not readily available when requested.

Example usage:

# Preprocess 4 files concurrently.
filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
dataset = filenames.apply(tf.contrib.data.parallel_interleave(
    lambda filename: tf.data.TFRecordDataset(filename),
    cycle_length=4))

WARNING: If sloppy is True, the order of produced elements is not deterministic.

Args:

  • map_func: A function mapping a nested structure of tensors to a Dataset.
  • cycle_length: The number of input Datasets to interleave from in parallel.
  • block_length: The number of consecutive elements to pull from an input Dataset before advancing to the next input Dataset.
  • sloppy: If false, elements are produced in deterministic order. Otherwise, the implementation is allowed, for the sake of expediency, to produce elements in a non-deterministic order.
  • buffer_output_elements: The number of elements each iterator being interleaved should buffer (similar to the .prefetch() transformation for each interleaved iterator).
  • prefetch_input_elements: The number of input elements to transform to iterators before they are needed for interleaving.

Returns:

Dataset transformation function, which can be passed to tf.data.Dataset.apply.
