tensorflow数据读入、数据加载

先读下官网的api:

https://www.tensorflow.org/api_guides/python/reading_data

分为三种

  • placeholder,把数据feed进去,这个需要自己写数据迭代器和shuffle,还要控制epoch。
  • 读取文件,tfrecord,csv等
  • 预加载的文件,都进内存,小数据下使用。
    推荐从文件中读取,使用tfrecord,让tf自动load和shuffle文件、还可以控制epoch。

将训练数据转换成tfrecord

def write_tfrecord(writer, char_ids, label_id):
    """Serialize one (char_ids, label_id) pair as a tf.train.Example and write it.

    :param writer: an open tf.python_io.TFRecordWriter
    :param char_ids: list of int — the encoded character ids for one sample
    :param label_id: int — the class label for this sample
    """
    example = tf.train.Example(features=tf.train.Features(feature={
        'char_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=char_ids)),
        # Int64List.value must be an iterable of ints; the original passed the
        # bare int `label_id`, which raises a TypeError at runtime.
        'label_id': tf.train.Feature(int64_list=tf.train.Int64List(value=[label_id]))
    }))
    writer.write(example.SerializeToString())

# Convert each line of the training file into one Example in the TFRecord file.
writer = tf.python_io.TFRecordWriter(output_file_name)
# The original line was a syntax error: `for line open(your_files):` — the
# `in` keyword was missing.  Use a context manager so the input file is closed.
with open(your_files) as fin:
    for line in fin:
        char_ids = ...
        label_id = ...
        write_tfrecord(writer, char_ids, label_id)
writer.close()

读取tfrecord,并解码

  • 单纯解码
def test_read_tfrecords():
    """Eagerly iterate a TFRecord file and decode each Example (no TF graph needed)."""
    record_path = "./data/train.tfrecords"
    for raw_record in tf.python_io.tf_record_iterator(record_path):
        parsed = tf.train.Example()
        parsed.ParseFromString(raw_record)
        # Walk the Example proto to pull out the stored feature values.
        feature_map = parsed.features.feature
        x = feature_map['char_ids'].int64_list.value
        y = feature_map['label_id'].int64_list.value
        # do something

将tfrecord、batch和train联系起来

def read_and_decode(filename_queue):
    """Read one serialized Example from the queue and decode it into tensors.

    :param filename_queue: queue of TFRecord file names (from string_input_producer)
    :return: (char_ids, label_id) — an int32 tensor of shape [max_seq_len]
             and a scalar int32 label
    """
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    feature_spec = {
        'char_ids': tf.FixedLenFeature([max_seq_len], tf.int64),
        'label_id': tf.FixedLenFeature([1], tf.int64),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    char_ids = tf.cast(parsed['char_ids'], tf.int32)
    # the label was stored as a length-1 list, so index out the scalar
    label_id = tf.cast(parsed['label_id'], tf.int32)[0]
    return char_ids, label_id

def inputs(filename_list, batch_size, num_epochs=None):
    """Build the shuffled input pipeline: file names -> queue -> decoded batches.

    :param filename_list: list of TFRecord file paths
    :param batch_size: number of examples per batch
    :param num_epochs: how many passes over the files (None = cycle forever)
    :return: (batch_char_ids, batch_label_id) batched tensors
    """
    queue = tf.train.string_input_producer(filename_list, num_epochs=num_epochs)
    single_char_ids, single_label = read_and_decode(queue)
    # min_after_dequeue controls shuffle quality; capacity leaves a few
    # batches of headroom above it so the dequeue threads never starve.
    batch_char_ids, batch_label_id = tf.train.shuffle_batch(
        [single_char_ids, single_label],
        batch_size=batch_size,
        num_threads=12,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=1000)
    return batch_char_ids, batch_label_id

# main session
# Build the input pipeline ONCE, before creating the session.  The original
# called inputs(...) inside the training loop, which added new graph nodes on
# every iteration (a memory leak) and created queue variables AFTER init_op
# had already run, so they were never initialized.
x_batch, y_batch = inputs(data_files, batch_size, num_epochs=50)
train_op = ...

# string_input_producer(num_epochs=...) tracks the epoch count in a *local*
# variable, so local variables must be initialized as well — and only after
# the full graph has been built.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
sess = tf.Session()
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    while not coord.should_stop():
        # Run training steps until the epoch limit raises OutOfRangeError.
        sess.run(train_op)
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()
# Wait for threads to finish.
coord.join(threads)
sess.close()

你可能感兴趣的:(tensorflow数据读入、数据加载)