练习基础代码(包括音频文件、音频文件读取代码、预加重代码、分帧加窗代码、快速傅里叶变换代码)可从 GitHub 获取,仓库为 nwpuaslp/ASR_Course(https://github.com/nwpuaslp/ASR_Course)。
本节理论笔记见:语音识别入门第五节:基于GMM-HMM的语音识别系统
lab2_vit.C 中需要添加代码如下:
int startState = graph.get_start_state(); //to get the start state
for (int frmIdx = 0; frmIdx < frmCnt; frmIdx++){ //frame cycle
for(int stateIdx = 0; stateIdx < stateCnt; stateIdx++) {//state cycle
if(frmIdx == 0){
chart(frmIdx, startState).assign(0, -1); //set the forward probability to 0 and Arc ID to -1
stateIdx = startState;
}
int arcCnt = graph.get_arc_count(stateIdx); // get the arc number
int arcId = graph.get_first_arc_id(stateIdx); //the frist arc id is -1
for(int arcIdx = 0; arcIdx < arcCnt; arcIdx++){ //arc cycle
int curArcId = arcId; //Current arc id=-1
Arc arc; //Arc class instantiation
arcId = graph.get_arc(arcId, arc);
int dstState = arc.get_dst_state(); //Jump to the next state
int gmmIndex = arc.get_gmm(); //get GMM ID
double logPorb = chart(frmIdx, stateIdx).get_log_prob() + gmmProbs(frmIdx, gmmIndex) + arc.get_log_prob();
if(logPorb > chart(frmIdx + 1, dstState).get_log_prob()){
chart(frmIdx + 1, dstState).assign(logPorb, curArcId);
}
}
if(frmIdx == 0)
break;
}
}
gmm_util.C中需要添加代码如下:
m_gaussCounts[gaussIdx] += posterior;
for(int dimIdx = 0; dimIdx < dimCnt; dimIdx++){
m_gaussStats1(gaussIdx, dimIdx) += posterior * feats[dimIdx];
m_gaussStats2(gaussIdx, dimIdx) += posterior * feats[dimIdx] * feats[dimIdx];
}
// Re-estimate the mean and variance of every Gaussian from the accumulated
// statistics:
//   mean = stats1 / count
//   var  = (stats2 - 2 * stats1 * mean + mean^2 * count) / count
for (int gaussIdx = 0; gaussIdx < gaussCnt; gaussIdx++) {
    const double count = m_gaussCounts[gaussIdx];

    // A Gaussian that collected no (or a non-positive / NaN) count would make
    // both divisions below produce NaN or infinity. Leave its current
    // parameters untouched instead of overwriting them with garbage.
    if (!(count > 0.0)) {
        continue;
    }

    for (int dimIdx = 0; dimIdx < dimCnt; dimIdx++) {
        const double sum1 = m_gaussStats1(gaussIdx, dimIdx);
        const double sum2 = m_gaussStats2(gaussIdx, dimIdx);

        const double newMean = sum1 / count;
        const double newVar =
            (sum2 - 2 * sum1 * newMean + newMean * newMean * count) / count;

        m_gmmSet.set_gaussian_mean(gaussIdx, dimIdx, newMean);
        m_gmmSet.set_gaussian_var(gaussIdx, dimIdx, newVar);
    }
}
lab2_fb.C中需要添加代码如下:
int startState = graph.get_start_state();
for(int frmIdx = 0; frmIdx < frmCnt; frmIdx++){
for(int stateIdx = 0; stateIdx < stateCnt; stateIdx++){
if(frmIdx == 0){
chart(frmIdx, startState).set_forw_log_prob(0);
}
int arcCnt = graph.get_arc_count(stateIdx);
int arcId = graph.get_first_arc_id(stateIdx);
for(int arcIdx = 0; arcIdx < arcCnt; arcIdx++){
Arc arc;
arcId = graph.get_arc(arcId, arc);
int dstState = arc.get_dst_state();
int gmmIndex = arc.get_gmm();
double logProb = chart(frmIdx, stateIdx).get_forw_log_prob() + gmmProbs(frmIdx, gmmIndex) + arc.get_log_prob();
vector logProbList = {logProb, chart(frmIdx + 1, dstState).get_forw_log_prob()};
logProb = add_log_probs(logProbList);
chart(frmIdx + 1, dstState).set_forw_log_prob(logProb);
}
}
}
for(int frmIdx = frmCnt; frmIdx > 0; --frmIdx){
for(int stateIdx = 0; stateIdx < stateCnt; ++ stateIdx){
int arcCnt = graph.get_arc_count(stateIdx);
int arcId = graph.get_first_arc_id(stateIdx);
for(int arcIdx = 0; arcIdx < arcCnt; ++arcIdx){
Arc arc;
arcId = graph.get_arc(arcId, arc);
int dstState = arc.get_dst_state();
double logProb = chart(frmIdx - 1, stateIdx).get_forw_log_prob() +
arc.get_log_prob() +
gmmProbs(frmIdx - 1, arc.get_gmm()) +
chart(frmIdx, dstState).get_back_log_prob();
gmmCountList.push_back(GmmCount(arc.get_gmm(), frmIdx - 1, exp(logProb - uttLogProb)));
}
}
}