由于音频数据的长度不固定,其 TFRecord 的存储与读取方式与图像数据不同,在此记录下来,方便以后快速查看。
import os
import six
import random
import numpy as np
import tensorflow as tf
import data_input_tfrecord
# Feature / target-sequence configuration shared by the reader and writer.
fea_conf = {
    'SOS_ID': 2,  # id placed at position 0 of tgt_ids
    'EOS_ID': 3,  # id used to fill tgt_labels after the real labels
    'BATCH_SIZE': 2,  # batch size for the fixed-size padded_batch reading path
    'FRAME_SIZE': 80,  # per-frame feature dim; frames are reshaped to [-1, FRAME_SIZE]
    'GRAPHEME_TARGET_SEQUENCE_LENGTH': 620,  # fixed length of tgt_ids/tgt_labels/tgt_paddings
}
def read_TFRecord(file_pattern, symbol, bucketing=True):
    """Build a TF1 input pipeline over '<file_pattern>/<symbol>.tfrecords-*'.

    Each record (features 'uttid', 'label', 'frames', as produced by
    write_TFRecord in this file) is parsed into a 6-tuple
    (uttid, tgt_ids, tgt_labels, tgt_paddings, frames, src_paddings).
    The records are then shuffled, batched, repeated forever and prefetched.

    Args:
      file_pattern: directory holding the shards. An optional leading
        'tfrecord:' prefix is removed.
      symbol: split name used as the shard file prefix (e.g. 'train').
      bucketing: if True (default), batch by frame count with
        bucket_by_sequence_length (boundaries [500, 1000], batch sizes
        [512, 256, 64]). If False, use padded_batch with a fixed
        fea_conf['BATCH_SIZE'].

    Returns:
      The 6 batched tensors from a one-shot iterator, in the order
      (uttids, tgt_ids, tgt_labels, tgt_paddings, frames, src_paddings).
      With L = fea_conf['GRAPHEME_TARGET_SEQUENCE_LENGTH'], the per-example
      shapes before batching are [n_uttid], [1, L], [1, L], [1, L],
      [T, FRAME_SIZE] and [T].
    """
    frame_size = fea_conf['FRAME_SIZE']
    max_tgt_len = fea_conf['GRAPHEME_TARGET_SEQUENCE_LENGTH']
    sos_id = fea_conf['SOS_ID']
    eos_id = fea_conf['EOS_ID']

    pattern = os.path.join(file_pattern, '%s.tfrecords-*' % symbol)
    # str.lstrip() removes a *set of characters*, not a prefix, so the old
    # lstrip('tfrecord:') also ate leading letters of relative paths.
    prefix = 'tfrecord:'
    if pattern.startswith(prefix):
        pattern = pattern[len(prefix):]

    def ParseAndProcess(record):
        """Parse one serialized tf.Example into the 6 model input tensors."""
        features = {
            'uttid': tf.VarLenFeature(tf.string),
            'label': tf.VarLenFeature(tf.int64),
            'frames': tf.VarLenFeature(tf.float32),
        }
        example = tf.parse_single_example(record, features)
        fval = {k: v.values for k, v in six.iteritems(example)}
        # Frames were stored flattened; restore [num_frames, frame_size].
        frames = tf.reshape(fval['frames'], shape=[-1, frame_size])
        # NOTE(review): src_paddings is 1.0 on real frames and batching pads it
        # with 0.0, the opposite of tgt_paddings (0.0 = real, 1.0 = padding).
        # Confirm which convention the consumer expects.
        src_paddings = tf.ones([tf.shape(frames)[0]], dtype=tf.float32)
        label = tf.cast(tf.reshape(fval['label'], shape=[1, -1]), tf.int32)
        label_len = tf.shape(label)[-1]
        # Assumes label_len < max_tgt_len. Otherwise tf.ones() below gets a
        # negative dimension and the op fails at run time.
        # tgt_labels: the labels followed by EOS fill, total length max_tgt_len.
        tgt_labels = tf.concat(
            [label, tf.fill([1, max_tgt_len - label_len], eos_id)], axis=1)
        # tgt_ids: SOS followed by tgt_labels shifted right by one position
        # (the last position of tgt_labels is dropped).
        tgt_ids = tf.concat(
            [tf.fill([1, 1], sos_id),
             tf.slice(tgt_labels, [0, 0], [1, max_tgt_len - 1])],
            axis=1)
        # 0.0 over the labels plus the first EOS; 1.0 over the remaining fill.
        tgt_paddings = tf.concat(
            [tf.zeros([1, label_len + 1], dtype=tf.float32),
             tf.ones([1, max_tgt_len - label_len - 1], dtype=tf.float32)],
            axis=1)
        return fval['uttid'], tgt_ids, tgt_labels, tgt_paddings, frames, src_paddings

    def element_length_fn(uttids, tgt_ids, tgt_labels, tgt_paddings, frames, src_paddings):
        """Bucketing key: number of frames. Takes one argument per tuple component."""
        return tf.shape(frames)[0]

    # One entry per component of the 6-tuple returned by ParseAndProcess.
    padded_shapes = ([None], [1, None], [1, None], [1, None], [None, frame_size], [None])

    dataset = tf.data.Dataset.list_files(pattern, shuffle=True).apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=1, sloppy=True))
    dataset = dataset.map(ParseAndProcess, num_parallel_calls=4)
    dataset = dataset.shuffle(512)
    if bucketing:
        # Variable batch size: shorter utterances go into larger batches.
        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                element_length_func=element_length_fn,
                bucket_boundaries=[500, 1000],
                bucket_batch_sizes=[512, 256, 64],
                padded_shapes=padded_shapes))
    else:
        dataset = dataset.padded_batch(
            fea_conf['BATCH_SIZE'], padded_shapes=padded_shapes)
    # Repeat before prefetch so the prefetch buffer is not reset every epoch.
    dataset = dataset.repeat()
    dataset = dataset.prefetch(buffer_size=1)
    iterator = dataset.make_one_shot_iterator()
    (uttids, tgt_ids, tgt_labels, tgt_paddings, frames,
     src_paddings) = iterator.get_next()
    return uttids, tgt_ids, tgt_labels, tgt_paddings, frames, src_paddings
def write_TFRecord(record_file_path, symbol):
    """Serialize every utterance from a SpeechLoader into sharded TFRecord files.

    Shards are named '<record_file_path>/<symbol>.tfrecords-XXXXX-of-YYYYY'.
    There is 1 shard when symbol == 'test' and 100 otherwise. Each utterance
    goes to a uniformly random shard as a tf.Example with three features:
    'uttid' (bytes 'utt_<n>'), 'label' (flattened int64) and 'frames'
    (flattened float). The output directory is not created here.

    Args:
      record_file_path: directory the shard files are written into.
      symbol: split name. It selects the loader config './<symbol>_data.conf'
        and the shard file prefix.
    """
    data_loader = data_input_tfrecord.SpeechLoader(
        "./%s_data.conf" % symbol)
    num_shards = 1 if symbol == 'test' else 100
    recordio_writers = [
        tf.python_io.TFRecordWriter(
            "%s/%s.tfrecords-%5.5d-of-%5.5d" % (record_file_path, symbol, s, num_shards))
        for s in range(num_shards)
    ]
    uttid = 0
    # try/finally so the shard files are always closed, even if building or
    # writing an example raises.
    try:
        while True:
            try:
                (frames, tgt_labels,
                 tgt_padding, tgt_seq_length) = data_loader.next()
            except Exception as e:
                # NOTE(review): any exception from the loader, including a real
                # error, is treated as end-of-data. Narrow this once the
                # loader's end signal is known (StopIteration?) -- TODO confirm.
                print('finished...............')
                print('exception:', e)
                break
            print("========uttid:", uttid)
            # A one-element list of bytes, e.g. [b'utt_0'].
            uttid_str = [tf.compat.as_bytes('_'.join(['utt', str(uttid)]))]
            uttid += 1
            # Flatten to 1-D (row-major). The reader reshapes frames back to
            # [-1, FRAME_SIZE]. Assumes frames is [frame_num, fea_dim] -- TODO confirm.
            flat_frames = frames.flatten(order='C')
            flat_label = tgt_labels.flatten(order='C')
            feature = {
                'uttid': tf.train.Feature(bytes_list=tf.train.BytesList(value=uttid_str)),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=flat_label)),
                'frames': tf.train.Feature(float_list=tf.train.FloatList(value=flat_frames)),
            }
            ex = tf.train.Example(features=tf.train.Features(feature=feature))
            random.choice(recordio_writers).write(ex.SerializeToString())
    finally:
        for f in recordio_writers:
            f.close()
if __name__ == '__main__':
    import sys
    # Usage: python <script> [write|read] [split]
    # With no arguments this writes the 'train' split. The read path used to
    # sit after an unconditional exit() and could never run; it is now
    # selected explicitly.
    mode = sys.argv[1] if len(sys.argv) > 1 else 'write'
    symbol = sys.argv[2] if len(sys.argv) > 2 else 'train'
    record_file_path = '/data/kevin/speech_data/hkust/%s_tfrecord'%(symbol)
    if mode == 'write':
        # write tfrecord
        write_TFRecord(record_file_path, symbol)
    elif mode == 'read':
        # read tfrecord and print one batch
        with tf.Session() as sess:
            (utt_ids, tgt_ids, tgt_labels, tgt_paddings, src_frames,
             src_paddings) = read_TFRecord(record_file_path, symbol)
            (utt_id, tgtid, tgtlabel, tgtpad, frame,
             framepad) = sess.run([utt_ids, tgt_ids, tgt_labels, tgt_paddings, src_frames, src_paddings])
            print('utt_ids:', utt_id)
            print('tgt_label:', tgtlabel)
            print('tgt_label.shape:', tgtlabel.shape)
            print('tgt_ids:', tgtid)
            print('tgt_ids.shape:', tgtid.shape)
            print('tgt_paddings:', tgtpad)
            print('tgt_paddings.shape:', tgtpad.shape)
            print('src_frame:', frame)
            print('src_frame.shape:', frame.shape)
            print('src_paddings:', framepad)
    else:
        sys.exit("unknown mode %r; expected 'write' or 'read'" % mode)
读取数据时支持两种批处理方式:一是固定 batch_size 的 padded_batch;二是按序列长度(帧数)分桶、batch_size 随桶变化的 bucket_by_sequence_length。