参考博客
语音识别
框架
最开始打算用Kaldi进行语音识别,但我不熟悉其中DNN部分的框架,所以改用Python实现。
本文的模型都在CPU上运行。
算法:声学模型为CNN + CTC,语言模型为Transformer。
train.py
import os
import tensorflow as tf
from utils import get_data, data_hparams
from keras.callbacks import ModelCheckpoint
# 0. Prepare the training and validation data ------------------------------
def _load_split(data_type):
    """Build a get_data loader for *data_type* using the demo settings below.

    Both splits share the same settings; only data_type differs.
    """
    args = data_hparams()
    args.data_type = data_type
    args.data_path = 'data/'
    args.thchs30 = True
    args.batch_size = 4
    # Demo setting: keep only the first 10 utterances; set to None to use all.
    args.data_length = 10
    args.shuffle = False
    return get_data(args)


train_data = _load_split('train')
# Validation split; currently not consumed by the training loops below.
dev_data = _load_split('dev')
# 1. Acoustic model training -----------------------------------
from model_speech.cnn_ctc import Am, am_hparams

am_args = am_hparams()
am_args.vocab_size = len(train_data.am_vocab)  # output size must match the AM vocab
# am_args.gpu_nums = 1
am_args.lr = 0.0008
am_args.is_training = True
am = Am(am_args)

# Resume from previously saved weights, if any.
if os.path.exists('logs_am/model.h5'):
    print('load acoustic model...')
    am.ctc_model.load_weights('logs_am/model.h5')

epochs = 10
batch_num = len(train_data.wav_lst) // train_data.batch_size
if batch_num == 0:
    # With steps_per_epoch=0 training cannot run; fail with a clear message.
    raise ValueError('training set has fewer utterances (%d) than batch_size (%d)'
                     % (len(train_data.wav_lst), train_data.batch_size))
for k in range(epochs):
    print('this is the', k+1, 'th epochs trainning !!!')
    # A fresh generator per epoch restarts iteration from the first batch.
    batch = train_data.get_am_batch()
    am.ctc_model.fit_generator(batch, steps_per_epoch=batch_num, epochs=1)

# Create the output directory so the very first run (before logs_am/ exists)
# can save its weights.
os.makedirs('logs_am', exist_ok=True)
am.ctc_model.save_weights('logs_am/model.h5')
# 2. Language model training -------------------------------------------
from model_language.transformer import Lm, lm_hparams

lm_args = lm_hparams()
lm_args.num_heads = 8
lm_args.num_blocks = 6
lm_args.input_vocab_size = len(train_data.pny_vocab)
lm_args.label_vocab_size = len(train_data.han_vocab)
lm_args.max_length = 100
lm_args.hidden_units = 512
lm_args.dropout_rate = 0.2
lm_args.lr = 0.0003
lm_args.is_training = True
lm = Lm(lm_args)

epochs = 10
# Build the saver and the merged summary op explicitly inside the LM graph.
with lm.graph.as_default():
    saver = tf.train.Saver()
    merged = tf.summary.merge_all()
with tf.Session(graph=lm.graph) as sess:
    sess.run(tf.global_variables_initializer())
    add_num = 0  # number of epochs already trained by earlier runs
    if os.path.exists('logs_lm/checkpoint'):
        print('loading language model...')
        latest = tf.train.latest_checkpoint('logs_lm')
        # Checkpoints are written as logs_lm/model_<total epochs>, see save below.
        add_num = int(latest.split('_')[-1])
        saver.restore(sess, latest)
    writer = tf.summary.FileWriter('logs_lm/tensorboard', lm.graph)
    try:
        for k in range(epochs):
            total_loss = 0
            batch = train_data.get_lm_batch()
            for i in range(batch_num):
                input_batch, label_batch = next(batch)
                feed = {lm.x: input_batch, lm.y: label_batch}
                cost, _ = sess.run([lm.mean_loss, lm.train_op], feed_dict=feed)
                total_loss += cost
                # Offset by epochs already trained so a resumed run extends the
                # TensorBoard curves instead of overwriting steps from 0.
                step = (add_num + k) * batch_num + i
                if step % 10 == 0:
                    rs = sess.run(merged, feed_dict=feed)
                    writer.add_summary(rs, step)
            print('epochs', k+1, ': average loss = ', total_loss/batch_num)
        saver.save(sess, 'logs_lm/model_%d' % (epochs + add_num))
    finally:
        # Flush and release the event file even if training fails midway.
        writer.close()
test.py
#coding=utf-8
import os
import difflib
import tensorflow as tf
import numpy as np
from utils import decode_ctc, GetEditDistance
# 0. Build the vocabularies needed for decoding. The hyper-parameters must match
#    the ones used for training, because token ids are positions in vocabularies
#    rebuilt from the data. (Alternatively, save the vocabularies to disk and
#    load them here.)
from utils import get_data, data_hparams

data_args = data_hparams()
train_data = get_data(args=data_args)
# 1. Acoustic model -----------------------------------
from model_speech.cnn_ctc import Am, am_hparams

am_args = am_hparams()
# The output layer size has to equal the AM vocabulary used at training time.
am_args.vocab_size = len(train_data.am_vocab)
am = Am(am_args)
print('loading acoustic model...')
am.ctc_model.load_weights(
    'logs_am/model.h5'
)
# 2. Language model -------------------------------------------
from model_language.transformer import Lm, lm_hparams

lm_args = lm_hparams()
lm_args.input_vocab_size = len(train_data.pny_vocab)
lm_args.label_vocab_size = len(train_data.han_vocab)
lm_args.dropout_rate = 0.  # inference: no dropout
# NOTE(review): num_heads / num_blocks / hidden_units / max_length are left at
# the lm_hparams() defaults; they must equal the training values (8, 6, 512,
# 100) or restoring the checkpoint will fail — confirm the defaults match.
print('loading language model...')
lm = Lm(lm_args)
sess = tf.Session(graph=lm.graph)
with lm.graph.as_default():
    saver = tf.train.Saver()
latest = tf.train.latest_checkpoint('logs_lm')
if latest is None:
    # Fail with a clear message instead of an obscure error from restore(None).
    sess.close()
    raise FileNotFoundError('no language-model checkpoint found in logs_lm/; '
                            'run the training script first')
with sess.as_default():
    saver.restore(sess, latest)
# 3. Prepare the test data. It does not have to be the training data; pick the
#    split through data_args.data_type. It really should be 'test', but 'train'
#    is used here because the demo model is tiny: on 'test' the results are
#    meaningless and out-of-vocabulary words appear.
data_args.batch_size = 1
data_args.shuffle = False
data_args.data_type = 'train'
test_data = get_data(data_args)
# 4. Run the test -------------------------------------------
am_batch = test_data.get_am_batch()
word_num = 0
word_error_num = 0
# get_am_batch() cycles forever, so never request more examples than exist;
# otherwise test_data.pny_lst[i] would raise IndexError.
num_examples = min(5, len(test_data.pny_lst))
# NOTE(review): get_am_batch drops utterances whose audio is too short for
# their label; if that happens, next(am_batch) and pny_lst[i]/han_lst[i] drift
# out of alignment.
try:
    for i in range(num_examples):
        print('\n the ', i, 'th example.')
        # Recognise with the trained acoustic model.
        inputs, _ = next(am_batch)
        x = inputs['the_inputs']
        y = test_data.pny_lst[i]
        result = am.model.predict(x, steps=1)
        # Convert the numeric output into pinyin tokens.
        _, text = decode_ctc(result, train_data.am_vocab)
        text = ' '.join(text)
        print('文本结果:', text)
        print('原文结果:', ' '.join(y))
        with sess.as_default():
            # Map pinyin to LM input ids and predict hanzi.
            text = text.strip('\n').split(' ')
            x = np.array([train_data.pny_vocab.index(pny) for pny in text])
            x = x.reshape(1, -1)
            preds = sess.run(lm.preds, {lm.x: x})
        label = test_data.han_lst[i]
        got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
        print('原文汉字:', label)
        print('识别结果:', got)
        # Distances are over hanzi characters, so this is effectively a
        # character error rate; each example's errors are capped at its length.
        word_error_num += min(len(label), GetEditDistance(label, got))
        word_num += len(label)
    if word_num:
        print('词错误率:', word_error_num / word_num)
finally:
    sess.close()
数据准备
utils.py
import os
import difflib
import numpy as np
import tensorflow as tf
import scipy.io.wavfile as wav
from tqdm import tqdm
from scipy.fftpack import fft
from python_speech_features import mfcc
from random import shuffle
from keras import backend as K
def data_hparams():
    """Return the default data-loading settings as a tf HParams object.

    Fields:
      data_type   -- split to load: 'train', 'dev' or 'test'.
      data_path   -- prefix prepended to every wav path from the list files.
      thchs30, aishell, prime, stcmd -- which corpora to include
                     (prime and stcmd are only read for the 'train' split).
      batch_size  -- utterances per batch.
      data_length -- keep only the first N utterances; a falsy value keeps all.
      shuffle     -- reshuffle the acoustic-model batch order on every pass.
    """
    defaults = dict(
        data_type='train',
        data_path='data/',
        thchs30=True,
        aishell=True,
        prime=True,
        stcmd=True,
        batch_size=1,
        data_length=10,
        shuffle=True,
    )
    return tf.contrib.training.HParams(**defaults)
class get_data():
    """Load the speech corpora and serve batches to the acoustic and language models.

    List files are tab-separated with one utterance per line:
    wav path, TAB, space-separated pinyin, TAB, hanzi text.

    After construction the instance holds:
      wav_lst / pny_lst / han_lst -- parallel per-utterance lists (wav path,
          list of pinyin tokens, hanzi string);
      am_vocab  -- pinyin tokens in first-seen order followed by '_'
          (presumably the CTC blank — confirm against the acoustic model);
      pny_vocab -- LM input vocabulary; index 0 is '' and id 0 is used as
          padding by get_lm_batch;
      han_vocab -- LM label vocabulary of single characters; index 0 is ''.
    """

    # List files read for each data_type as (corpus flag attribute, file name);
    # a file is read only when its flag is truthy. Order matters: it fixes the
    # utterance order and therefore the vocabulary ids.
    _SPLIT_FILES = {
        'train': (('thchs30', 'thchs_train.txt'),
                  ('aishell', 'aishell_train.txt'),
                  ('prime', 'prime.txt'),
                  ('stcmd', 'stcmd.txt')),
        'dev': (('thchs30', 'thchs_dev.txt'),
                ('aishell', 'aishell_dev.txt')),
        'test': (('thchs30', 'thchs_test.txt'),
                 ('aishell', 'aishell_test.txt')),
    }

    def __init__(self, args):
        """Copy settings from an HParams-like *args* and load the data lists."""
        self.data_type = args.data_type
        self.data_path = args.data_path
        self.thchs30 = args.thchs30
        self.aishell = args.aishell
        self.prime = args.prime
        self.stcmd = args.stcmd
        self.data_length = args.data_length
        self.batch_size = args.batch_size
        self.shuffle = args.shuffle
        self.source_init()

    def source_init(self):
        """Read the list files for self.data_type and build the vocabularies.

        Raises:
            ValueError: if data_type is not 'train', 'dev' or 'test' (this used
                to yield an empty dataset and an endless get_am_batch loop).
        """
        print('get source list...')
        if self.data_type not in self._SPLIT_FILES:
            raise ValueError("data_type must be 'train', 'dev' or 'test', got %r"
                             % (self.data_type,))
        read_files = [name for flag, name in self._SPLIT_FILES[self.data_type]
                      if getattr(self, flag)]
        self.wav_lst = []
        self.pny_lst = []
        self.han_lst = []
        for file in read_files:
            print('load ', file, ' data...')
            # List files always come from the local 'data/' directory;
            # data_path is only the prefix for the wav paths they contain.
            sub_file = 'data/' + file
            with open(sub_file, 'r', encoding='utf-8-sig') as f:
                data = f.readlines()
            for line in tqdm(data):
                if not line.strip():
                    continue  # tolerate blank lines instead of failing the unpack
                wav_file, pny, han = line.split('\t')
                self.wav_lst.append(wav_file)
                self.pny_lst.append(pny.split(' '))
                self.han_lst.append(han.strip('\n'))
        if self.data_length:
            self.wav_lst = self.wav_lst[:self.data_length]
            self.pny_lst = self.pny_lst[:self.data_length]
            self.han_lst = self.han_lst[:self.data_length]
        print('make am vocab...')
        self.am_vocab = self.mk_am_vocab(self.pny_lst)
        print('make lm pinyin vocab...')
        self.pny_vocab = self.mk_lm_pny_vocab(self.pny_lst)
        print('make lm hanzi vocab...')
        self.han_vocab = self.mk_lm_han_vocab(self.han_lst)

    def get_am_batch(self):
        """Yield acoustic-model batches forever, reshuffling each pass if self.shuffle.

        Each item is (inputs, outputs): inputs has keys 'the_inputs',
        'the_labels', 'input_length' and 'label_length'; outputs holds a dummy
        all-zero 'ctc' target. Utterances too short for their CTC label are
        dropped, and a batch left empty by that filter is skipped.

        Raises:
            ValueError: (from next()) when a full pass produces no batch and
                none was ever produced — empty data, batch_size larger than the
                dataset, or every utterance too short. Previously this looped
                forever or crashed inside wav_padding.
        """
        order = list(range(len(self.wav_lst)))
        produced_any = False
        while True:
            if self.shuffle:
                shuffle(order)
            for i in range(len(self.wav_lst) // self.batch_size):
                begin = i * self.batch_size
                end = begin + self.batch_size
                wav_data_lst = []
                label_data_lst = []
                for index in order[begin:end]:
                    fbank = compute_fbank(self.data_path + self.wav_lst[index])
                    # Zero-pad the frame axis to the smallest multiple of 8 that
                    # is strictly greater than the frame count.
                    pad_fbank = np.zeros((fbank.shape[0] // 8 * 8 + 8, fbank.shape[1]))
                    pad_fbank[:fbank.shape[0], :] = fbank
                    label = self.pny2id(self.pny_lst[index], self.am_vocab)
                    # Keep the utterance only if frames // 8 (presumably the
                    # model's time down-sampling) cover the CTC label length.
                    if pad_fbank.shape[0] // 8 >= self.ctc_len(label):
                        wav_data_lst.append(pad_fbank)
                        label_data_lst.append(label)
                if not wav_data_lst:
                    continue  # every utterance in this slice was filtered out
                pad_wav_data, input_length = self.wav_padding(wav_data_lst)
                pad_label_data, label_length = self.label_padding(label_data_lst)
                inputs = {
                    'the_inputs': pad_wav_data,
                    'the_labels': pad_label_data,
                    'input_length': input_length,
                    'label_length': label_length,
                }
                outputs = {'ctc': np.zeros(pad_wav_data.shape[0])}
                produced_any = True
                yield inputs, outputs
            if not produced_any:
                raise ValueError('get_am_batch: a full pass over %d utterances produced '
                                 'no batch (empty data, batch_size too large, or all '
                                 'audio too short for its labels)' % len(self.wav_lst))

    def get_lm_batch(self):
        """Yield one pass of (pinyin ids, hanzi ids) batches for the language model.

        Rows are right-padded with id 0 to the longest pinyin sequence in the
        batch. Assumes each hanzi string has as many characters as its pinyin
        list has tokens; a longer label would produce a ragged row. A trailing
        partial batch is dropped.
        """
        batch_num = len(self.pny_lst) // self.batch_size
        for k in range(batch_num):
            begin = k * self.batch_size
            end = begin + self.batch_size
            input_batch = self.pny_lst[begin:end]
            label_batch = self.han_lst[begin:end]
            max_len = max([len(line) for line in input_batch])
            input_batch = np.array(
                [self.pny2id(line, self.pny_vocab) + [0] * (max_len - len(line)) for line in input_batch])
            label_batch = np.array(
                [self.han2id(line, self.han_vocab) + [0] * (max_len - len(line)) for line in label_batch])
            yield input_batch, label_batch

    def pny2id(self, line, vocab):
        """Map a list of pinyin tokens to their indices in *vocab*."""
        return [vocab.index(pny) for pny in line]

    def han2id(self, line, vocab):
        """Map each character of *line* to its index in *vocab*."""
        return [vocab.index(han) for han in line]

    def wav_padding(self, wav_data_lst):
        """Stack 2-D feature matrices into a zero-padded (batch, max_frames, 200, 1) array.

        Returns the padded array and each input's frame count // 8. The
        feature width of 200 matches compute_fbank's output.
        """
        wav_lens = [len(data) for data in wav_data_lst]
        wav_max_len = max(wav_lens)
        wav_lens = np.array([leng // 8 for leng in wav_lens])
        new_wav_data_lst = np.zeros((len(wav_data_lst), wav_max_len, 200, 1))
        for i in range(len(wav_data_lst)):
            new_wav_data_lst[i, :wav_data_lst[i].shape[0], :, 0] = wav_data_lst[i]
        return new_wav_data_lst, wav_lens

    def label_padding(self, label_data_lst):
        """Zero-pad id lists into a (batch, max_len) float array; also return the lengths."""
        label_lens = np.array([len(label) for label in label_data_lst])
        max_label_len = max(label_lens)
        new_label_data_lst = np.zeros((len(label_data_lst), max_label_len))
        for i in range(len(label_data_lst)):
            new_label_data_lst[i][:len(label_data_lst[i])] = label_data_lst[i]
        return new_label_data_lst, label_lens

    def mk_am_vocab(self, data):
        """Return the distinct pinyin tokens in first-seen order, plus a final '_'."""
        vocab = []
        seen = set()  # O(1) membership instead of scanning the list each time
        for line in tqdm(data):
            for pny in line:
                if pny not in seen:
                    seen.add(pny)
                    vocab.append(pny)
        vocab.append('_')
        return vocab

    def mk_lm_pny_vocab(self, data):
        """Return '' followed by the distinct pinyin tokens in first-seen order."""
        vocab = ['']
        seen = {''}
        for line in tqdm(data):
            for pny in line:
                if pny not in seen:
                    seen.add(pny)
                    vocab.append(pny)
        return vocab

    def mk_lm_han_vocab(self, data):
        """Return '' followed by the distinct characters (spaces removed) in first-seen order."""
        vocab = ['']
        seen = {''}
        for line in tqdm(data):
            line = ''.join(line.split(' '))
            for han in line:
                if han not in seen:
                    seen.add(han)
                    vocab.append(han)
        return vocab

    def ctc_len(self, label):
        """Return the minimum number of CTC frames needed to emit *label*.

        CTC needs a blank between two identical consecutive labels, so each
        adjacent repeat costs one extra frame.
        """
        add_len = 0
        label_len = len(label)
        for i in range(label_len - 1):
            if label[i] == label[i + 1]:
                add_len += 1
        return label_len + add_len
# Extract MFCC features from an audio file.
def compute_mfcc(file):
    """Return 26-coefficient MFCC features of the wav at *file*.

    Only every third frame is kept, and the resulting matrix is transposed
    before being returned.
    """
    sample_rate, signal = wav.read(file)
    features = mfcc(signal, samplerate=sample_rate, numcep=26)
    subsampled = features[::3]
    return np.transpose(subsampled)
# Compute the time-frequency image (log spectrogram) of a signal.
def compute_fbank(file):
    """Return the log-magnitude spectrogram of the wav file at *file*.

    Frames of 400 samples are taken every 160 samples and multiplied by a
    Hamming window. Each frame is FFT'd, and only the first 200 bins are kept,
    because the spectrum of a real signal is symmetric. The result is
    log(1 + |FFT|), with shape (num_frames, 200) and dtype float64.

    The frame count is derived from a 25 ms window at a 10 ms shift, while the
    frame length and hop are fixed in samples, so this assumes 16 kHz audio.
    At other rates the last frames can run past the end of the signal.
    Clips too short for a single frame return an empty (0, 200) array.
    """
    window_len = 400  # samples per frame (25 ms at 16 kHz)
    hop = 160         # frame shift in samples (10 ms at 16 kHz)
    n_bins = window_len // 2
    time_window = 25  # ms
    x = np.linspace(0, window_len - 1, window_len, dtype=np.int64)
    w = 0.54 - 0.46 * np.cos(2 * np.pi * x / (window_len - 1))  # Hamming window
    fs, wavsignal = wav.read(file)
    wav_arr = np.array(wavsignal)
    # Number of windows; clamp at 0 so very short clips do not produce a
    # negative array size.
    range0_end = max(0, int(len(wavsignal) / fs * 1000 - time_window) // 10 + 1)
    # np.float was removed in NumPy 1.24; it was an alias of float64.
    data_input = np.zeros((range0_end, n_bins), dtype=np.float64)
    for i in range(range0_end):
        p_start = i * hop
        frame = wav_arr[p_start:p_start + window_len] * w
        data_input[i] = np.abs(fft(frame))[:n_bins]
    return np.log(data_input + 1)
# Edit distance used for the error rate ------------------------------------
def GetEditDistance(str1, str2):
    """Return the Levenshtein distance between two sequences.

    This is the minimum number of single-element insertions, deletions and
    substitutions that turn *str1* into *str2*. Works on strings (per
    character) and on lists of tokens.

    The previous difflib.SequenceMatcher-based version did not produce
    minimal edit sequences and could over-count. For example, it returned 14
    instead of 8 for 'abcdXYZefgh' vs 'efghXYZabcd'.
    """
    if len(str1) < len(str2):
        str1, str2 = str2, str1  # distance is symmetric; keep the DP row short
    previous = list(range(len(str2) + 1))
    for i, a in enumerate(str1, 1):
        current = [i]
        for j, b in enumerate(str2, 1):
            current.append(min(
                previous[j] + 1,             # delete a
                current[j - 1] + 1,          # insert b
                previous[j - 1] + (a != b),  # substitute (free when equal)
            ))
        previous = current
    return previous[-1]
# CTC decoder ------------------------------------
def decode_ctc(num_result, num2word):
    """Greedily CTC-decode the first item of an acoustic-model output.

    Args:
        num_result: rank-3 model output (the slicing below requires three
            axes); presumably (batch, time, vocab) — confirm against the model.
        num2word: list mapping token id -> token.

    Returns:
        (ids, tokens): the decoded id array for batch item 0 and the tokens
        those ids map to.
    """
    probs = num_result[:, :, :]
    # A single sequence length equal to the full time axis.
    seq_len = np.array([probs.shape[1]], dtype=np.int32)
    decoded = K.ctc_decode(probs, seq_len, greedy=True, beam_width=10, top_paths=1)
    ids = K.get_value(decoded[0][0])[0]
    tokens = [num2word[idx] for idx in ids]
    return ids, tokens
原github链接
交流 q 2531996920