Reading Data with tf.data

Following on from the previous post, once the data has been written to TFRecords files, the next step is to read it back from those files to train the model. Here we use the tf.data API for reading, which speeds up the input pipeline by parsing records in parallel and preparing batches ahead of the training step.

import tensorflow as tf  # TensorFlow 1.x style API


def read_and_decode(loader, handle, num_epochs=1):
    """Read and decode TFRecord-format data via the tf.data API."""
    # `loader`, `FLAGS` and `model_settings` are defined elsewhere in the project.
    batch_size = int(loader.batch_size() / FLAGS.gpu_num)
    feature_size = model_settings['fingerprint_size']

    def parse_exmp(serialized_example):
        """Parse one serialized Example into dense tensors."""
        features = tf.parse_single_example(serialized_example, features={
            'feature': tf.VarLenFeature(tf.float32),
            'label':   tf.VarLenFeature(tf.int64),
            'mask':    tf.VarLenFeature(tf.int64),
            'length':  tf.FixedLenFeature((), tf.int64)
        })

        length = tf.cast(features['length'], tf.int32)
        feature = tf.sparse_tensor_to_dense(features['feature'])
        feature = tf.reshape(feature, [length, feature_size])
        label = tf.sparse_tensor_to_dense(features['label'])
        mask = tf.sparse_tensor_to_dense(features['mask'])
        return feature, label, mask, length

    filenames = ['./train_input/tfrecords_file/train_dataset_%d.tfrecords' % i for i in range(10)]
    dataset = tf.data.TFRecordDataset(filenames)
    # Parse records in parallel, shuffle, repeat for num_epochs, and pad each
    # batch so variable-length sequences can be stacked together.
    dataset = dataset.map(parse_exmp, num_parallel_calls=64)
    dataset = dataset.shuffle(64).repeat(num_epochs).padded_batch(
        batch_size, padded_shapes=([None, feature_size], [None], [None], []))
    # Prefetch after batching so the next batch is prepared while the current
    # one is being consumed by the model.
    dataset = dataset.prefetch(buffer_size=1)
    train_iterator = dataset.make_initializable_iterator()

    # Feedable iterator: the concrete iterator is chosen at run time through
    # the string `handle` placeholder, so several datasets (e.g. train and
    # validation) can share the same input tensors.
    iterator = tf.data.Iterator.from_string_handle(
        handle, dataset.output_types, dataset.output_shapes)

    batch_data, batch_label, batch_mask, batch_length = iterator.get_next()

    # Transpose to time-major layout [time, batch, feature] for the model.
    if FLAGS.ctc_loss:
        return train_iterator, tf.transpose(batch_data, (1, 0, 2)), \
            batch_label, batch_mask, batch_length
    else:
        return train_iterator, tf.transpose(batch_data, (1, 0, 2)), \
            tf.transpose(batch_label, (1, 0)), tf.transpose(batch_mask, (1, 0)), batch_length
