TensorFlow: reading a TFRecord dataset

At runtime, TensorFlow prints this warning:

start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module
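For context, the warning is triggered by the old queue-based input pipeline. A minimal sketch of what that pattern looked like (the file name and feature shape are illustrative, not from the original code):

import tensorflow as tf

# Old-style queue-based reading, now deprecated in favor of tf.data.
filename_queue = tf.train.string_input_producer(["data.tfrecord"])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={"input_ids": tf.FixedLenFeature([128], tf.int64)})  # shape is hypothetical

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)  # emits the warning above
    ids = sess.run(features["input_ids"])
    # The queue has to be shut down explicitly when you are done.
    coord.request_stop()
    coord.join(threads)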

Update 2019-06-01: I didn't expect a casually pasted log to get so many views, so here is a proper write-up of the new API.
Previously, reading tfrecords files meant going through TensorFlow's queue runners, and after the sess.run you still had to close the queue, which was a bit of a hassle. The updated official API is tf.data.TFRecordDataset: it reads the files into a dataset directly, and building an iterator from that dataset gives you the tensors to run. Details below:

import tensorflow as tf

def get_input_data(input_file, seq_length, batch_size, num_labels):
    def parser(record):
        # Each record holds two fixed-length int64 features.
        name_to_features = {
            "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
            "label_ids": tf.FixedLenFeature([num_labels], tf.int64),
        }
        example = tf.parse_single_example(record, features=name_to_features)
        input_ids = example["input_ids"]
        labels = example["label_ids"]
        return input_ids, labels

    dataset = tf.data.TFRecordDataset(input_file)
    # Shuffle before batching so individual examples, not whole batches, get shuffled.
    dataset = dataset.map(parser).shuffle(buffer_size=1000).repeat().batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    input_ids, labels = iterator.get_next()

    return input_ids, labels
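
(Aside: the parser assumes each record was serialized with int64 feature lists named "input_ids" and "label_ids". A minimal sketch of writing such records, with a hypothetical file name and values:)

def write_record(writer, input_ids, label_ids):
    # Use the same feature names the parser above expects.
    feature = {
        "input_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids)),
        "label_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=label_ids)),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

with tf.python_io.TFRecordWriter("train.tfrecord") as writer:  # file name is hypothetical
    write_record(writer, input_ids=[101, 2054, 102], label_ids=[0, 1])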

input_ids, labels = get_input_data(input_file, seq_len, batch_size, num_labels)

# One sess.run gives you the numpy arrays to put in a feed_dict
ids, y = sess.run([input_ids, labels])
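
Putting it together, a typical loop just calls sess.run once per step to pull the next batch; train_op and the two placeholders below are hypothetical stand-ins for your own model:

with tf.Session() as sess:
    for step in range(1000):
        # Each sess.run advances the one-shot iterator by one batch.
        batch_ids, batch_labels = sess.run([input_ids, labels])
        # Feed the resulting numpy arrays to the model's placeholders.
        sess.run(train_op, feed_dict={ids_ph: batch_ids, labels_ph: batch_labels})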

For a complete example, you can refer to my rewritten version of the BERT model.

Demo from the official docs:

The tf.data API supports a variety of file formats so that you can process large datasets that do not fit in memory. For example, the TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The tf.data.TFRecordDataset class enables you to stream over the contents of one or more TFRecord files as part of an input pipeline.

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

The filenames argument to the TFRecordDataset initializer can either be a string, a list of strings, or a tf.Tensor of strings. Therefore if you have two sets of files for training and validation purposes, you can create a factory method that produces the dataset, taking filenames as an input argument:

def make_dataset(filenames):
  dataset = tf.data.TFRecordDataset(filenames)
  dataset = dataset.map(...)  # Parse the record into tensors.
  dataset = dataset.repeat()  # Repeat the input indefinitely.
  dataset = dataset.batch(32)
  return dataset

training_dataset = make_dataset(["/var/data/training1.tfrecord", ...])
validation_dataset = make_dataset(["/var/data/validation1.tfrecord", ...])
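
The "tf.Tensor of strings" option is handy when one graph should switch between training and validation files. A minimal sketch using a placeholder and an initializable iterator (file names are illustrative):

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.repeat().batch(32)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    # Point the pipeline at the training files...
    sess.run(iterator.initializer,
             feed_dict={filenames: ["/var/data/training1.tfrecord"]})
    # ...and later re-initialize the same iterator with the validation files.
    sess.run(iterator.initializer,
             feed_dict={filenames: ["/var/data/validation1.tfrecord"]})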
