tfrecords: storing and reading audio data

Audio data is variable-length, so storing and reading it with TFRecords differs from images. I'm writing the approach down here so it can be looked up quickly later.
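
The core trick is to flatten the variable-length [frame_num, fea_dim] feature matrix into a 1-D FloatList when writing, and to parse it back with a VarLenFeature plus tf.reshape when reading. A minimal sketch of just that round trip (toy sizes, TF 1.x API, same calls as the full script below):

import numpy as np
import tensorflow as tf

# Toy utterance: 3 frames x 4 feature dims (sizes are illustrative only).
frames = np.random.rand(3, 4).astype(np.float32)

# Write: flatten so the variable-length matrix fits into a FloatList.
ex = tf.train.Example(features=tf.train.Features(feature={
    'frames': tf.train.Feature(
        float_list=tf.train.FloatList(value=frames.flatten())),
}))
serialized = ex.SerializeToString()

# Read: parse as a variable-length feature, then restore the 2-D shape.
parsed = tf.parse_single_example(
    serialized, {'frames': tf.VarLenFeature(tf.float32)})
restored = tf.reshape(parsed['frames'].values, [-1, 4])

with tf.Session() as sess:
  print(sess.run(restored).shape)  # (3, 4)

The full write/read script follows.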

import os
import six
import random
import numpy as np
import tensorflow as tf

import data_input_tfrecord

fea_conf = dict([
    ('SOS_ID', 2),                             # start-of-sequence token id
    ('EOS_ID', 3),                             # end-of-sequence token id
    ('BATCH_SIZE', 2),                         # used by the padded_batch path
    ('FRAME_SIZE', 80),                        # feature dims per frame
    ('GRAPHEME_TARGET_SEQUENCE_LENGTH', 620),  # fixed target length after padding
])

def read_TFRecord(file_pattern, symbol):
  """read tfrecord."""

  file_pattern = os.path.join(file_pattern, '%s.tfrecords-*'%symbol)

  def ParseAndProcess(record):
    """Parses a serialized tf.Example record."""
    features = [
        ('uttid', tf.VarLenFeature(tf.string)),
        ('label', tf.VarLenFeature(tf.int64)),
        ('frames', tf.VarLenFeature(tf.float32)),
    ]
    example = tf.parse_single_example(record, dict(features))

    fval = {k: v.values for k, v in six.iteritems(example)}
    # Restore the flattened frames back to [frame_num, FRAME_SIZE].
    fval['frames'] = tf.reshape(
        fval['frames'], shape=[-1, fea_conf['FRAME_SIZE']])
    # One entry per real frame; positions appended later by batch padding
    # default to 0.
    src_paddings = tf.ones([tf.shape(fval['frames'])[0]], dtype=tf.float32)

    fval['label'] = tf.reshape(fval['label'], shape=[1, -1])
    # Pad the label sequence with EOS_ID up to the fixed target length.
    tgt_labels = tf.concat(
        [tf.cast(fval['label'], tf.int32),
         tf.fill([1, fea_conf['GRAPHEME_TARGET_SEQUENCE_LENGTH'] - tf.shape(fval['label'])[-1]],
                 fea_conf['EOS_ID'])],
        axis=1)
    # Decoder inputs: SOS_ID followed by the labels shifted right by one step.
    tgt_ids = tf.concat(
        [tf.fill([1, 1], fea_conf['SOS_ID']),
         tf.slice(tgt_labels, [0, 0], [1, fea_conf['GRAPHEME_TARGET_SEQUENCE_LENGTH'] - 1])],
        axis=1)
    # 0 marks a real target position (labels plus one EOS), 1 marks padding.
    tgt_paddings = tf.concat(
        [tf.zeros([1, tf.shape(fval['label'])[-1] + 1], dtype=tf.float32),
         tf.ones([1, fea_conf['GRAPHEME_TARGET_SEQUENCE_LENGTH'] - tf.shape(fval['label'])[-1] - 1],
                 dtype=tf.float32)],
        axis=1)
    return fval['uttid'], tgt_ids, tgt_labels, tgt_paddings, fval['frames'], src_paddings

  def element_length_fn(uttids, tgt_ids, tgt_labels, tgt_paddings, frames, src_paddings):
    # Bucket each utterance by its number of frames.
    return tf.shape(frames)[0]
  dataset_factory = tf.data.TFRecordDataset
  dataset = (
      tf.data.Dataset.list_files(file_pattern, shuffle=True).apply(
          tf.data.experimental.parallel_interleave(
              dataset_factory,
              cycle_length=1,
              sloppy=True)))
  dataset = dataset.map(
      ParseAndProcess, num_parallel_calls=4)
  dataset = dataset.shuffle(512)
  dataset = dataset.apply(
      tf.data.experimental.bucket_by_sequence_length(
          element_length_func=element_length_fn,
          bucket_boundaries=[500, 1000],
          bucket_batch_sizes=[512, 256, 64],
          padded_shapes=([None], [1, None], [1, None], [1, None], [None, 80], [None])))
  #dataset = dataset.padded_batch(
  #   fea_conf['BATCH_SIZE'], padded_shapes=([None], [1, None], [1, None], [1, None], [None, 80], [None]))
  dataset = dataset.prefetch(buffer_size=1)
  dataset = dataset.repeat()
  iterator = dataset.make_one_shot_iterator()
  input_batch = iterator.get_next()

  return (input_batch[0], input_batch[1], input_batch[2],
          input_batch[3], input_batch[4], input_batch[5])

def write_TFRecord(record_file_path, symbol):
  """ write tfrecord."""
  uttid = 0
  recordio_writers = []
  data_loader = data_input_tfrecord.SpeechLoader(
      "./%s_data.conf"%symbol)
  end = 1 if symbol == 'test' else 100
  for s in range(end):
    filepath = "%s/%s.tfrecords-%5.5d-of-%5.5d"%(record_file_path,symbol, s, end)
    recordio_writers += [tf.python_io.TFRecordWriter(filepath)]
  while True:
    try:
      (frames, tgt_labels,
       tgt_padding, tgt_seq_length) = data_loader.next()
    except Exception as e:
      print('finished...............')
      print('exception:', e)
      break
    else:
      print("========uttid:", uttid)
      # uttid_str: a one-element list holding the utterance id as bytes.
      uttid_str = [tf.compat.as_bytes('_'.join(['utt', str(uttid)]))]
      uttid += 1
      # frames.shape: [frame_num, fea_dim] --> flattened to [frame_num * fea_dim]
      flat_frames = frames.flatten(order='C')
      # tgt_labels is flattened to 1-D in the same way.
      flat_label = tgt_labels.flatten(order='C')
      feature = {
          'uttid':tf.train.Feature(bytes_list=tf.train.BytesList(value=uttid_str)),
          'label':tf.train.Feature(int64_list=tf.train.Int64List(value=flat_label)),
          'frames':tf.train.Feature(float_list=tf.train.FloatList(value=flat_frames))
      }
      ex = tf.train.Example(features=tf.train.Features(feature=feature))
      outf = recordio_writers[random.randint(0, len(recordio_writers) - 1)]
      outf.write(ex.SerializeToString())
      
  for f in recordio_writers:
    f.close()
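
# Hypothetical helper, not part of the original pipeline: a quick sanity
# check that counts how many Examples landed in the shards of a symbol.
def count_records(record_file_path, symbol):
  pattern = os.path.join(record_file_path, '%s.tfrecords-*' % symbol)
  total = 0
  for shard in tf.gfile.Glob(pattern):
    total += sum(1 for _ in tf.python_io.tf_record_iterator(shard))
  return total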

if __name__ == '__main__':
  symbol='train'
  record_file_path = '/data/kevin/speech_data/hkust/%s_tfrecord'%(symbol)
  # write tfrecord
  write_TFRecord(record_file_path, symbol)
  # Comment out this exit() to run the reading test below after writing.
  exit()

  # read tfrecord
  with tf.Session() as sess:
    (utt_ids, tgt_ids, tgt_labels, tgt_paddings, src_frames,
        src_paddings) = read_TFRecord(record_file_path, symbol)

    (utt_id, tgtid, tgtlabel, tgtpad, frame,
        framepad) = sess.run([utt_ids, tgt_ids, tgt_labels, tgt_paddings, src_frames, src_paddings])

    print('utt_ids:', utt_id)
    print('tgt_label:', tgtlabel)
    print('tgt_label.shape:', tgtlabel.shape)
    print('tgt_ids:', tgtid)
    print('tgt_ids.shape:', tgtid.shape)
    print('tgt_paddings:', tgtpad)
    print('tgt_paddings.shape:', tgtpad.shape)
    print('src_frame:', frame)
    print('src_frame.shape:', frame.shape)
    print('src_paddings:', framepad)

Two ways of reading are covered: a fixed batch_size (the padded_batch call, commented out in read_TFRecord) and length-bucketed batching without a fixed batch_size.
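
For the fixed batch_size path, the commented-out padded_batch call can stand in for the bucket_by_sequence_length transform; a minimal sketch using the same six padded shapes:

# Fixed batch size: pad every element of the batch to the longest one.
dataset = dataset.padded_batch(
    fea_conf['BATCH_SIZE'],
    padded_shapes=([None], [1, None], [1, None], [1, None], [None, 80], [None]))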
