tfrecords读取文件

以前做科研论文的时候, 所使用的音频数据比较少, 所以都是直接读进内存中在feeding给placeholder。现在在做一些偏工程的项目时就发现远远不行了,feeding训练速度远远提不上来。所以这两周都在为训练提速而折磨。在此记录下来尝试的方式。
tensorflow推荐使用tfrecords来存储数据, 这样能加快数据的读取。

def convert_to_tfrecord(loader):
    ''' modefy batch_size=1 in './conf/train_ce.conf' before convert to tfrecord data format '''
    def write_tfrecords(queue, i):
        start_time = time.time()
        while queue.empty():
            if time.time()-start_time > 600:               #超时队列中还没有数据该进程就退出
                print('wait timeout! proc %d exit!'%i)
                exit()
            time.sleep(1)
        writer = tf.python_io.TFRecordWriter('./train_input/tfrecords_file/train_dataset_%d.tfrecords'%i)
        while queue.qsize():
            batch = queue.get()                            # 从队列中获取一个样本
            example = tf.train.Example(features=tf.train.Features(feature={
                'feature':  tf.train.Feature(float_list=tf.train.FloatList(value=batch[0].flatten())),
                'label':    tf.train.Feature(int64_list=tf.train.Int64List(value=batch[1].flatten())),
                'mask':     tf.train.Feature(int64_list=tf.train.Int64List(value=batch[2].flatten())),
                'length':   tf.train.Feature(int64_list=tf.train.Int64List(value=[batch[3][0][0]]))
                #'feature_shape': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(batch[0].shape)))
            }))                                            # 这里将二维的feature label mask 转为一维的进行存储
            writer.write(example.SerializeToString())
        writer.close()

    start = time.time()
    queue = Queue(512)
    proc_record = []
    for i in range(10):
        p = Process(target=write_tfrecords, args=(queue, i)) #开10个进程用来写入数据
        p.start()
        proc_record.append(p)
    num = 0
    while True:
        try:
            batch = loader.next()                            # 获取一个样本, 压入队列
        except StopIteration:
            tf.logging.info('finished convert to tfrecords')
            break
        if batch is not None:
            queue.put(batch)
            num += 1
        else:
            break
    for p in proc_record:   p.join()                        # 等待所有进程结束
    print('num:', num)
    print('time:', time.time()-start)

程序写了一个多进程写入tfrecords, 在主进程中读取数据压入队列,再开辟10个进程从队列中读取数据, 因为我的loader.next加载数据比较长,所以在子进程中设置了循环等待。

在尝试过多线程, 应为python GIL的原因, 所以速度没有提升, 改成了多进程。

你可能感兴趣的:(tfrecords读取文件)