一、简介
得益于Kaldi社区的日益活跃和Kaldi作者Povey的大力推荐,现在有越来越多的公司和团体开始使用chain model。chain model的优越性在于:(1)使用了单状态的biphone,建模粒度更大,有些类似于CTC;(2)采用了低帧率策略,DNN每三帧输出一次,解码速度更快;(3)使用了区分性训练,准确率更高;(4)改进了MMI,提出了Lattice-free MMI(LF-MMI),训练速度更快。
二、源码解析
我们结合kaldi的源码,来逐步分析chain model的训练过程。
#include "chain/chain-numerator.h"
#include "cudamatrix/cu-vector.h"
namespace kaldi {
namespace chain {
// Sets up a numerator forward-backward computation from the supervision
// (labels) and the result of the network's forward pass.
//
//   supervision  - provides the numerator FST (supervision.fst) together with
//                  num_sequences, frames_per_sequence and label_dim.
//   nnet_output  - the network output; it must have
//                  num_sequences * frames_per_sequence rows and label_dim
//                  columns (asserted below).
NumeratorComputation::NumeratorComputation(
    const Supervision &supervision,
    const CuMatrixBase<BaseFloat> &nnet_output):
    supervision_(supervision),
    nnet_output_(nnet_output) {
  // Work out the time index that each state of the numerator FST belongs to;
  // the result is kept in fst_state_times_ and read by ComputeLookupIndexes().
  ComputeFstStateTimes(supervision_.fst, &fst_state_times_);
  // The supervision and the network output must describe the same number of
  // frames and the same number of output labels.
  KALDI_ASSERT(supervision.num_sequences * supervision.frames_per_sequence ==
               nnet_output.NumRows() &&
               supervision.label_dim == nnet_output.NumCols());
}
void NumeratorComputation::ComputeLookupIndexes() {
int32 num_states = supervision_.fst.NumStates();
int32 num_arcs_guess = num_states * 2;
fst_output_indexes_.reserve(num_arcs_guess);
int32 frames_per_sequence = supervision_.frames_per_sequence,
num_sequences = supervision_.num_sequences,
cur_time = 0;
// the following is a CPU version of nnet_output_indexes_. It is a list of
// pairs (row-index, column-index) which index nnet_output_. The row-index
// corresponds to the time-frame 't', and the column-index the pdf-id, but we
// have to be a little careful with the row-index because there is a
// reordering that happens if supervision_.num_sequences > 1.
//
// output-index) and denominator_indexes_cpu is a list of pairs (c,
// history-state-index).
std::vector nnet_output_indexes_cpu;
// index_map_this_frame is a map, only valid for t == cur_time,
// from the pdf-id to the index into nnet_output_indexes_cpu for the
// likelihood at (cur_time, pdf-id).
unordered_map index_map_this_frame;
typedef unordered_map::iterator IterType;
for (int32 state = 0; state < num_states; state++) {
int32 t = fst_state_times_[state];
if (t != cur_time) {
KALDI_ASSERT(t == cur_time + 1);
index_map_this_frame.clear();
cur_time = t;
}
for (fst::ArcIterator aiter(supervision_.fst, state);
!aiter.Done(); aiter.Next()) {
int32 pdf_id = aiter.Value().ilabel - 1;
KALDI_ASSERT(pdf_id >= 0 && pdf_id < supervision_.label_dim);
int32 index = nnet_output_indexes_cpu.size();
// the next few lines are a more efficient way of doing the following:
// if (index_map_this_frame.count(pdf_id) == 0) {
// index = index_map_this_frame[pdf_id] = nnet_output_indexes_cpu.size();
// nnet_output_indexes_cpu.push_back(pair(pdf_id, row-index));
// } else {
// index = index_map_this_frame[pdf_id];
// }
std::pair p = index_map_this_frame.insert(
std::pair(pdf_id, index));
if (p.second) { // Was inserted -> map had no key 'output_index'
Int32Pair pair; // we can't use constructors as this was declared in C.
pair.first = ComputeRowIndex(t, frames_per_sequence, num_sequences);
pair.second = pdf_id;
nnet_output_indexes_cpu.push_back(pair);
} else { // was not inserted -> set 'index' to the existing index.
index = p.first->second;
}
//保存每个边在输出矩阵的位置,和nnet_output_indexes_对应
fst_output_indexes_.push_back(index);
}
}
//用来保存fst中每条边对应的时间点和pdf id
nnet_output_indexes_ = nnet_output_indexes_cpu;
KALDI_ASSERT(!fst_output_indexes_.empty());
}
// Forward pass over the numerator FST.
//
// Fills log_alpha_ (one log-probability per FST state) and tot_log_prob_ (the
// log-sum over all final states), and returns tot_log_prob_ multiplied by
// supervision_.weight.  Also builds the lookup tables via
// ComputeLookupIndexes() and fetches the needed network outputs into
// nnet_logprobs_, both of which Backward() reads afterwards.
BaseFloat NumeratorComputation::Forward() {
  ComputeLookupIndexes();
  // Fetch the network output value for every (row, pdf-id) pair that some arc
  // refers to.
  nnet_logprobs_.Resize(nnet_output_indexes_.Dim(), kUndefined);
  nnet_output_.Lookup(nnet_output_indexes_, nnet_logprobs_.Data());

  const fst::StdVectorFst &fst = supervision_.fst;
  KALDI_ASSERT(fst.Start() == 0);
  int32 num_states = fst.NumStates();
  log_alpha_.Resize(num_states, kUndefined);
  log_alpha_.Set(-std::numeric_limits<double>::infinity());
  tot_log_prob_ = -std::numeric_limits<double>::infinity();

  // Storage for the forward result.
  log_alpha_(0) = 0.0;  // note, state zero is the start state, we checked above

  const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();
  // Walks fst_output_indexes_ in the same state/arc order in which
  // ComputeLookupIndexes() filled it.
  std::vector<int32>::const_iterator fst_output_indexes_iter =
      fst_output_indexes_.begin();

  double *log_alpha_data = log_alpha_.Data();

  // NOTE(review): reading log_alpha of 'state' as final here presumably relies
  // on every arc leading to a higher-numbered state (topological order) --
  // this function does not check it.
  for (int32 state = 0; state < num_states; state++) {
    double this_log_alpha = log_alpha_data[state];
    for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, state); !aiter.Done();
         aiter.Next(), ++fst_output_indexes_iter) {
      const fst::StdArc &arc = aiter.Value();
      int32 nextstate = arc.nextstate;
      // Arc weights are costs, so negate to get a log-probability.
      BaseFloat transition_logprob = -arc.weight.Value();
      int32 index = *fst_output_indexes_iter;
      BaseFloat pseudo_loglike = nnet_logprob_data[index];
      double &next_log_alpha = log_alpha_data[nextstate];
      // Forward recursion: accumulate this path's contribution in log space.
      next_log_alpha = LogAdd(next_log_alpha, pseudo_loglike +
                              transition_logprob + this_log_alpha);
    }
    // Final states contribute to the total log-probability.
    if (fst.Final(state) != fst::TropicalWeight::Zero()) {
      BaseFloat final_logprob = -fst.Final(state).Value();
      tot_log_prob_ = LogAdd(tot_log_prob_,
                             this_log_alpha + final_logprob);
    }
  }
  // Every arc index must have been consumed exactly once.
  KALDI_ASSERT(fst_output_indexes_iter ==
               fst_output_indexes_.end());
  return tot_log_prob_ * supervision_.weight;
}
//后向计算,必须在前向计算之后
void NumeratorComputation::Backward(
CuMatrixBase *nnet_output_deriv) {
const fst::StdVectorFst &fst = supervision_.fst;
int32 num_states = fst.NumStates();
log_beta_.Resize(num_states, kUndefined);
nnet_logprob_derivs_.Resize(nnet_logprobs_.Dim());
// we'll be counting backwards and moving the 'fst_output_indexes_iter'
// pointer back.
const int32 *fst_output_indexes_iter = &(fst_output_indexes_[0]) +
fst_output_indexes_.size();
const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();
double tot_log_prob = tot_log_prob_;
double *log_beta_data = log_beta_.Data();
const double *log_alpha_data = log_alpha_.Data();
BaseFloat *nnet_logprob_deriv_data = nnet_logprob_derivs_.Data();
for (int32 state = num_states - 1; state >= 0; state--) {
int32 this_num_arcs = fst.NumArcs(state);
// on the backward pass we access the fst_output_indexes_ vector in a zigzag
// pattern.
fst_output_indexes_iter -= this_num_arcs;
const int32 *this_fst_output_indexes_iter = fst_output_indexes_iter;
double this_log_beta = -fst.Final(state).Value();
double this_log_alpha = log_alpha_data[state];
for (fst::ArcIterator aiter(fst, state); !aiter.Done();
aiter.Next(), this_fst_output_indexes_iter++) {
const fst::StdArc &arc = aiter.Value();
double next_log_beta = log_beta_data[arc.nextstate];
BaseFloat transition_logprob = -arc.weight.Value();
int32 index = *this_fst_output_indexes_iter;
BaseFloat pseudo_loglike = nnet_logprob_data[index];
//后向计算公式
this_log_beta = LogAdd(this_log_beta, pseudo_loglike +
transition_logprob + next_log_beta);
//每条边在当前时间出现的比例
BaseFloat occupation_logprob = this_log_alpha + pseudo_loglike +
transition_logprob + next_log_beta - tot_log_prob,
//转化为概率
occupation_prob = exp(occupation_logprob);
nnet_logprob_deriv_data[index] += occupation_prob;
}
// check for -inf.
KALDI_PARANOID_ASSERT(this_log_beta - this_log_beta == 0);
log_beta_data[state] = this_log_beta;
}
KALDI_ASSERT(fst_output_indexes_iter == &(fst_output_indexes_[0]));
int32 start_state = 0; // the fact that the start state is numbered 0 is
// implied by other properties of the FST
// (epsilon-free-ness and topological sorting, and
// connectedness).
double tot_log_prob_backward = log_beta_(start_state);
if (!ApproxEqual(tot_log_prob_backward, tot_log_prob_))
KALDI_WARN << "Disagreement in forward/backward log-probs: "
<< tot_log_prob_backward << " vs. " << tot_log_prob_;
// copy this data to GPU.
CuVector nnet_logprob_deriv_cuda;
nnet_logprob_deriv_cuda.Swap(&nnet_logprob_derivs_);
//返回每条边上pdf对应出现的概率
nnet_output_deriv->AddElements(supervision_.weight, nnet_output_indexes_,
nnet_logprob_deriv_cuda.Data());
}
} // namespace chain
} // namespace kaldi
三、结论
以上就是chain model中分子fst对应的前后向计算过程。