Viterbi解码理论与实战
笔者最近着手研究基于HMM的语音识别系统,之前博文基于C++工具手写了提取MFCC语音特征(具体可以观看之前博文),最同时,也对基于GMM-HMM的语音识别训练过程进行了理论推导,现本文对基于Viterbi的解码方法进行详细的研究。
曾看过很多语音识别书(余栋的语音识别实战、陈果 果的kaldi实战以及张雪英的数位语音处理等书)、课程(七月在线、深蓝学院等课程),上述研究对于Viterbi应用于语音识别仅停留于理论阶段,笔者针对他们的理论应用于语音识别有保留的赞同,以下则是本人对Viterbi算法的理论理解与工程实现方法,持有不同意见的读者不妨对本人提出建议:
一、Viterbi解码理论部分
1.1 余栋书对Viterbi算法的理论介绍
首先看上书对Viterbi解码提供的伪代码部分:
对上述伪代码笔者曾一头雾水,对其中很多念理解不清楚,实际上其具体含义可以理解为:从状态初始至结尾找出概率最大的路径,该路径可以通过回溯找出对应状态序列,最后进行反转得到正序语音序列。
对于此伪代码,我相信如果不经历过自己手写过Viterbi算法,很难理解其中具体的实现过程,现笔者开始对其进行深入剖析。
1.2 个人见解
值得指出的是,笔者很好在相关书上找出究竟是怎样对Viterbi概率进行准确的求解,即很少有介绍状态之间概率究竟是如果进行跳转以及计算的。值得指出的是:状态之间的计算仅是当前状态与之前一帧部分状态之间的计算,因为累积会造成概率小时或者概率爆炸的问题,故对其进行log计算,故累积变为累加,其可解决上述问题。
上述中何为“当前状态与之前一帧部分状态之间”,其实这句话对于理解Viterbi计算有很重要的作用,部分即为之前一帧仅有部分状态与当前状态有关系(即为有跳转)。必须说明的是,最后的每个状态的累加结果包含三部分:弧上概率(为防止其他概率不存在,一般会对弧上概率进行初始化,可以为负无穷或正无穷)、gmm概率(即为特征与高斯混合模型得到的每帧中每个状态对应的概率密度函数,pdf)以及状态转移概率(状态转移即对应与当前状态想关的之前状态)。
因为Viterbi算法涉及回溯找出概率最大对应的状态序列(实际上说是状态序列实际上是不太准确的,因为在解码过程中,每个HMM可能有N个状态,因此实际解码过程中使用的是弧序号来替代状态号,当然二者之间存在互相映射关系),因此对于回溯过程中容器的选择至关重要,其主要涉及弧号以及概率的保存(可以设置自定义变量保存二者),然后使用矩阵将自定义变量按顺序保存至矩阵容器中,因此后文实现过程中使用Matrix
二、Viterbi解码实战部分
2.1 特征数据
首先本文参考是的哥伦比亚大学的语音识别课程进行Viterbi解码,其首先针对一句语料进行解码,最后提取该句语料特征:68*12,其中,68表示帧数,12表示MFCC特征。
2.2 symbol 序列
音素序列即为建模单元,针对不同语音识别系统,可能有不同的建模单元,常见的有:状态、音素、音节、字等,笔者使用的symbol词典格式如下所示:
音素与其序列之间为互相映射关系,后期可以直接通过音素序号得到对应的音素,而音素序号可以通过训练得到模型得到状态对应的音素,音素序号与状态也为互相映射关系,但正如上文所述,实际应用过程中使用弧序号进行存储矩阵容器,然而弧号可与状态号进行互相映射。
很绕脑,但由上文可得:可以通过弧号得到对应的音素序列。
2.3 graph 序列
由下表第四行可知,每行以此保存着初始状态、跳转状态、对应GMM序号以及对应的音素。
值得说明的是:graph表中最后保存的是HMM状态链的终止状态,终止状态不发生状态跳转,其结果如下:
由上图可知,本文使用的graph序列终止状态为47与84两个不同的状态,后期进行Viterbi计算得到最大的终止状态。
2.4 gmm 训练结果
通过EM算法不断迭代存储GMM模型,以此包含:GMM对应的状态序号、各GMM权重、以及均值与方差,,其中均值与方差存储结果如下图所示:
因为语音识别常使用对角GMM对特征进行计算,其中每行有24个元素(均值与方差维度必须与特征维度一直,否则无法计算每帧属于某个状态的pdf),奇数列、偶数列分别表示方差与均值,因为是对角线元素,故方差值必为正值,读者可以参考哥大的代码,建立graph类,以下为模型建立核心过程:
string Graph::read(istream& inStrm, const string& name) {
clear();
string retStr;
string lineStr;
vector fieldList;
while (true) {
int peekChar = inStrm.peek();
if (peekChar != '#')
break;
getline(inStrm, lineStr);
split_string(lineStr, fieldList);
if ((fieldList.size() == 3) && (fieldList[0] == "#") &&
(fieldList[1] == "name:")) {
if (!name.empty() && (fieldList[2] != name))
throw runtime_error(str(format("Unexpected FSM name: %s/%s") %
name % fieldList[2]));
if (!retStr.empty())
throw runtime_error(str(format("FSM has two names: %s/%s") %
retStr % fieldList[2]));
retStr = fieldList[2];
}
}
int lastIdx = -1;
vector> arcList;
double logFactor = -log(10.0);
while (true) {
int peekChar = inStrm.peek();
if ((peekChar == '#') || (peekChar == EOF))
break;
getline(inStrm, lineStr);
split_string(lineStr, fieldList);
if (!fieldList.size())
continue;
try {
int srcIdx = lexical_cast(fieldList[0]);
if (srcIdx < 0)
throw runtime_error("Negative state index in FSM: " + lineStr);
if (m_start == -1)
m_start = srcIdx;
if (srcIdx > lastIdx)
lastIdx = srcIdx;
if (fieldList.size() <= 2) {
double logProb = (fieldList.size() > 1) ?
lexical_cast(fieldList[1]) * logFactor : 0.0;
if (m_finalLogProbs.find(srcIdx) != m_finalLogProbs.end())
throw runtime_error("Dup final state in FSM: " + lineStr);
m_finalLogProbs[srcIdx] = logProb;
continue;
}
if ((fieldList.size() == 3) || (fieldList.size() > 5))
throw runtime_error("Invalid num fields in FSM: " + lineStr);
unsigned dstIdx = lexical_cast(fieldList[1]);
if (dstIdx < 0)
throw runtime_error("Negative state index in FSM: " + lineStr);
if (dstIdx > lastIdx)
lastIdx = dstIdx;
int gmmIdx = -1;
const string& gmmStr = fieldList[2];
if ((gmmStr.length() >= 3) && (gmmStr.length() <= 9) &&
(gmmStr[0] == '<') && (gmmStr[gmmStr.length() - 1] == '>') &&
(string("epsilon").substr(0, gmmStr.length() - 2) ==
gmmStr.substr(1, gmmStr.length() - 2))) {
;
}
else {
gmmIdx = lexical_cast(gmmStr);
if (gmmIdx < 0)
throw runtime_error("Negative GMM index in FSM: " +
lineStr);
int wordIdx = !m_symTable->empty() ?
m_symTable->get_index(fieldList[3]) : 0;
if (wordIdx < 0)
throw runtime_error("OOV word in FSM: " + lineStr);
double logProb = (fieldList.size() > 4) ?
lexical_cast(fieldList[4]) * logFactor : 0.0;
Arc arc(dstIdx, gmmIdx, wordIdx, logProb);
arcList.push_back(make_pair(srcIdx, arc));
}
}
catch (bad_lexical_cast&)
{
throw runtime_error("Invalid type for field in FSM: " + lineStr);
}
}
if (m_start < 0)
throw runtime_error("Empty FSM.");
//lastIdx:122;
int stateCnt = lastIdx + 1;
m_stateMap.reserve(stateCnt);
m_arcList.reserve(arcList.size());
sort(arcList.begin(), arcList.end(), CompareArcs());
for (int arcIdx = 0; arcIdx < (int)arcList.size(); ++arcIdx) {
m_arcList.push_back(arcList[arcIdx].second);
//arcList[arcIdx].second.pringResults();
int srcIdx = arcList[arcIdx].first;
while ((int)m_stateMap.size() <= srcIdx)
m_stateMap.push_back(arcIdx);
}
//printVector<>(m_stateMap);
while ((int)m_stateMap.size() < stateCnt)
m_stateMap.push_back(arcList.size());
assert(((int)m_stateMap.size() == stateCnt) &&
(m_arcList.size() == arcList.size()));
for (int stateIdx = 0; stateIdx < stateCnt; ++stateIdx) {
int minArcIdx = get_min_arc_index(stateIdx);
int maxArcIdx = get_max_arc_index(stateIdx);
for (int arcIdx = minArcIdx; arcIdx < maxArcIdx; ++arcIdx)
assert(arcList[arcIdx].first == stateIdx);
}
return retStr;
}
上述代码主要将Graph表中元素存储到同的容器中,读者可以根据需求自己建立模型。
2.5 chart 容器
chart是Viterbi解码的核心解码图,其可以理解为格子图,其维度为:69*123,因为解码需要初始位置,故将68帧语音特征之前建立一帧作为起始位置,其中起始值为(-1, 0),其中-1表示弧号,0表示log似然值,其具体建立过程如下所示:
bool Lab2VitMain::init_utt() {
if (m_audioStrm.peek() == EOF) {
return false;
}
m_idStr = read_float_matrix(m_audioStrm, m_inAudio);
cout << "Processing utterance ID: " << m_idStr << endl;
m_frontEnd.get_feats(m_inAudio, m_feats);
if (m_feats.size2() != m_gmmSet.get_dim_count())
throw runtime_error("Mismatch in GMM and feat dim.");
if (m_doAlign) {
if (m_graphStrm.peek() == EOF)
throw runtime_error(
"Mismatch in number of audio files "
"and FSM's.");
m_graph.read(m_graphStrm, m_idStr);
}
if (m_graph.get_gmm_count() > m_gmmSet.get_gmm_count())
throw runtime_error(
"Mismatch in number of GMM's between "
"FSM and GmmSet.");
//m_gmmProbs矩阵维度为68*102,即为当前帧属于某个状态的pdf;
m_gmmSet.calc_gmm_probs(m_feats, m_gmmProbs);
m_chart.resize(m_feats.size1() + 1, m_graph.get_state_count());
m_chart.clear();
if (m_graph.get_start_state() < 0)
throw runtime_error("Graph has no start state.");
return true;
}
chart矩阵为什么对于语音识别解码至关重要,该矩阵每个元素保存的是弧号与至此状态最大的似然概率值,如果对chart容器有问题,可以加入微信解码群研究。
2.6 回溯
实际上chart图最后一帧即为最大似然概率对应的弧号,可以基于此回溯得到完成的弧号序列,弧号与状态之间存在映射关系,实际上每个弧号对应的状态的起始状态不就是前一帧的弧号对应的终止状态吗(这里大家可以仔细理解下)?,基于此可以以此得到状态最大似然概率对应的弧序列,进而映射为状态序列和音素序列,其模型具体建立过程如下所示:
double viterbi_backtrace(const Graph& graph, matrix& chart,
vector& outLabelList, bool doAlign) {
int frmCnt = chart.size1() - 1;
int stateCnt = chart.size2();
//finalStates存储终止状态对应的序号,且对其进行排序;
vector finalStates;
int finalCnt = graph.get_final_state_list(finalStates);
double bestLogProb = g_zeroLogProb;
int bestFinalState = -1;
for (int finalIdx = 0; finalIdx < finalCnt; ++finalIdx) {
int stateIdx = finalStates[finalIdx];
if (chart(frmCnt, stateIdx).get_log_prob() == g_zeroLogProb) continue;
//curLogProb表示终止状态对应的似然值与弧上概率的累加值;
//加上弧上概率是因为终止状态再进行log似然值累加时,终止状态上并未添加弧上概率;
double curLogProb = chart(frmCnt, stateIdx).get_log_prob() +
graph.get_final_log_prob(stateIdx);
if (curLogProb > bestLogProb)
bestLogProb = curLogProb, bestFinalState = stateIdx;
}
if (bestFinalState < 0) throw runtime_error("No complete paths found.");
outLabelList.clear();
int stateIdx = bestFinalState;
for (int frmIdx = frmCnt; --frmIdx >= 0;) {
assert((stateIdx >= 0) && (stateIdx < stateCnt));
int arcId = chart(frmIdx + 1, stateIdx).get_arc_id();
Arc arc;
graph.get_arc(arcId, arc);
assert((int)arc.get_dst_state() == stateIdx);
if (doAlign) {
throw runtime_error("Expect all arcs to have GMM.");
outLabelList.push_back(arc.get_gmm());
}
else if (arc.get_word() > 0) {
outLabelList.push_back(arc.get_word());
}
stateIdx = graph.get_src_state(arcId);
cout << stateIdx << endl;
}
if (stateIdx != graph.get_start_state())
throw runtime_error("Backtrace does not end at start state.");
reverse(outLabelList.begin(), outLabelList.end());
return bestLogProb;
}
至此Viterbi算法理论与实践简要介绍完毕。