语音识别系列7-chain model 之分子部分

一、简介

现在有越来越多的公司和团体开始使用chain model，这得益于kaldi社区的日益活跃以及kaldi作者Daniel Povey的大力推荐。chain model的优越性在于：1）使用单状态的biphone建模，建模粒度更大，有些类似于CTC；2）采用了低帧率策略，神经网络每三帧输出一次（frame subsampling factor为3），解码速度更快；3）使用了区分性训练准则，准确率更高；4）改进了MMI，提出了不需要预先生成lattice的Lattice-free MMI（LF-MMI），训练速度更快。

二、源码解析

下面结合kaldi的源码（src/chain/chain-numerator.cc），逐步分析chain model训练中分子（numerator）部分的计算过程，即在监督信息对应的分子fst上进行前向-后向计算。



#include "chain/chain-numerator.h"
#include "cudamatrix/cu-vector.h"

namespace kaldi {
namespace chain {

// Constructor: takes the numerator supervision (labels) and the neural-network
// output, and precomputes the time index of every state of the numerator FST.
//
// supervision: numerator supervision. Its `fst` is the numerator graph, and
//   num_sequences * frames_per_sequence / label_dim give the expected shape
//   of `nnet_output`.
// nnet_output: network output. It must have
//   supervision.num_sequences * supervision.frames_per_sequence rows and
//   supervision.label_dim columns (asserted below).
// NOTE(review): supervision_ / nnet_output_ are presumably reference members
// (see chain-numerator.h), in which case both arguments must outlive this
// object — confirm.
NumeratorComputation::NumeratorComputation(
    const Supervision &supervision,
    const CuMatrixBase<BaseFloat> &nnet_output):
    supervision_(supervision),
    nnet_output_(nnet_output) {
  // fst_state_times_[s] is the time index t that ComputeLookupIndexes() uses
  // as the frame of state s's outgoing arcs.
  ComputeFstStateTimes(supervision_.fst, &fst_state_times_);
  KALDI_ASSERT(supervision.num_sequences * supervision.frames_per_sequence ==
               nnet_output.NumRows() &&
               supervision.label_dim == nnet_output.NumCols());
}


void NumeratorComputation::ComputeLookupIndexes() {

  int32 num_states = supervision_.fst.NumStates();
  int32 num_arcs_guess = num_states * 2;
  fst_output_indexes_.reserve(num_arcs_guess);

  int32 frames_per_sequence = supervision_.frames_per_sequence,
      num_sequences = supervision_.num_sequences,
      cur_time = 0;

  // the following is a CPU version of nnet_output_indexes_.  It is a list of
  // pairs (row-index, column-index) which index nnet_output_.  The row-index
  // corresponds to the time-frame 't', and the column-index the pdf-id, but we
  // have to be a little careful with the row-index because there is a
  // reordering that happens if supervision_.num_sequences > 1.
  //

  // output-index) and denominator_indexes_cpu is a list of pairs (c,
  // history-state-index).
  std::vector nnet_output_indexes_cpu;

  // index_map_this_frame is a map, only valid for t == cur_time,
  // from the pdf-id to the index into nnet_output_indexes_cpu for the
  // likelihood at (cur_time, pdf-id).
  unordered_map index_map_this_frame;

  typedef unordered_map::iterator IterType;

  for (int32 state = 0; state < num_states; state++) {
    int32 t = fst_state_times_[state];
    if (t != cur_time) {
      KALDI_ASSERT(t == cur_time + 1);
      index_map_this_frame.clear();
      cur_time = t;
    }
    for (fst::ArcIterator aiter(supervision_.fst, state);
         !aiter.Done(); aiter.Next()) {
      int32 pdf_id = aiter.Value().ilabel - 1;
      KALDI_ASSERT(pdf_id >= 0 && pdf_id < supervision_.label_dim);

      int32 index = nnet_output_indexes_cpu.size();

      // the next few lines are a more efficient way of doing the following:
      // if (index_map_this_frame.count(pdf_id) == 0) {
      //   index = index_map_this_frame[pdf_id] = nnet_output_indexes_cpu.size();
      //   nnet_output_indexes_cpu.push_back(pair(pdf_id, row-index));
      // } else {
      //   index = index_map_this_frame[pdf_id];
      // }
      std::pair p = index_map_this_frame.insert(
          std::pair(pdf_id, index));
      if (p.second) {  // Was inserted -> map had no key 'output_index'
        Int32Pair pair;  // we can't use constructors as this was declared in C.
        pair.first = ComputeRowIndex(t, frames_per_sequence, num_sequences);
        pair.second = pdf_id;
        nnet_output_indexes_cpu.push_back(pair);
      } else {  // was not inserted -> set 'index' to the existing index.
        index = p.first->second;
      }
	  //保存每个边在输出矩阵的位置,和nnet_output_indexes_对应
      fst_output_indexes_.push_back(index);
    }
  }
  //用来保存fst中每条边对应的时间点和pdf id
  nnet_output_indexes_ = nnet_output_indexes_cpu;
  KALDI_ASSERT(!fst_output_indexes_.empty());
}

// Forward algorithm over the numerator FST, in log space.
//
// Steps:
//   1. Build the arc -> (row, pdf-id) lookup tables.
//   2. Gather the needed entries of nnet_output_ into nnet_logprobs_. These
//      are used as pseudo log-likelihoods.
//   3. Accumulate log_alpha_ state by state.
//
// States are visited in increasing order and each arc pushes its contribution
// to arc.nextstate. This is only correct if every arc goes from a lower- to a
// higher-numbered state, i.e. the FST is topologically sorted and
// epsilon-free.
//
// Returns: the total log-probability of the FST, i.e. the log-sum over final
// states of alpha + final log-prob, multiplied by supervision_.weight.
// The unweighted total is kept in tot_log_prob_ for use by Backward().
BaseFloat NumeratorComputation::Forward() {
  ComputeLookupIndexes();
  // One log-likelihood per distinct (row, pdf-id) pair used by the FST.
  nnet_logprobs_.Resize(nnet_output_indexes_.Dim(), kUndefined);
  nnet_output_.Lookup(nnet_output_indexes_, nnet_logprobs_.Data());
  const fst::StdVectorFst &fst = supervision_.fst;
  KALDI_ASSERT(fst.Start() == 0);
  int32 num_states = fst.NumStates();
  log_alpha_.Resize(num_states, kUndefined);
  log_alpha_.Set(-std::numeric_limits<double>::infinity());
  tot_log_prob_ = -std::numeric_limits<double>::infinity();

  log_alpha_(0) = 0.0;  // note, state zero is the start state, we checked above

  const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();
  // Walks fst_output_indexes_ in the same state/arc order in which
  // ComputeLookupIndexes() filled it.
  std::vector<int32>::const_iterator fst_output_indexes_iter =
      fst_output_indexes_.begin();

  double *log_alpha_data = log_alpha_.Data();

  for (int32 state = 0; state < num_states; state++) {
    double this_log_alpha = log_alpha_data[state];
    for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, state); !aiter.Done();
         aiter.Next(), ++fst_output_indexes_iter) {
      const fst::StdArc &arc = aiter.Value();
      int32 nextstate = arc.nextstate;
      // Arc weights are negated log-probabilities (tropical semiring).
      BaseFloat transition_logprob = -arc.weight.Value();
      int32 index = *fst_output_indexes_iter;
      BaseFloat pseudo_loglike = nnet_logprob_data[index];
      double &next_log_alpha = log_alpha_data[nextstate];
      // Forward recursion:
      //   alpha(next) = logsum(alpha(next),
      //                        alpha(state) + loglike + transition_logprob)
      next_log_alpha = LogAdd(next_log_alpha, pseudo_loglike +
                              transition_logprob + this_log_alpha);
    }
    // Final states contribute alpha + final log-prob to the total.
    if (fst.Final(state) != fst::TropicalWeight::Zero()) {
      BaseFloat final_logprob = -fst.Final(state).Value();
      tot_log_prob_ = LogAdd(tot_log_prob_,
                             this_log_alpha + final_logprob);
    }
  }
  KALDI_ASSERT(fst_output_indexes_iter ==
               fst_output_indexes_.end());
  return tot_log_prob_ * supervision_.weight;
}

//后向计算,必须在前向计算之后
void NumeratorComputation::Backward(
    CuMatrixBase *nnet_output_deriv) {
  const fst::StdVectorFst &fst = supervision_.fst;
  int32 num_states = fst.NumStates();
  log_beta_.Resize(num_states, kUndefined);
  nnet_logprob_derivs_.Resize(nnet_logprobs_.Dim());

  // we'll be counting backwards and moving the 'fst_output_indexes_iter'
  // pointer back.
  const int32 *fst_output_indexes_iter = &(fst_output_indexes_[0]) +
      fst_output_indexes_.size();
  const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();
  double tot_log_prob = tot_log_prob_;
  double *log_beta_data = log_beta_.Data();
  const double *log_alpha_data = log_alpha_.Data();
  BaseFloat *nnet_logprob_deriv_data = nnet_logprob_derivs_.Data();

  for (int32 state = num_states - 1; state >= 0; state--) {
    int32 this_num_arcs  = fst.NumArcs(state);
    // on the backward pass we access the fst_output_indexes_ vector in a zigzag
    // pattern.
    fst_output_indexes_iter -= this_num_arcs;
    const int32 *this_fst_output_indexes_iter = fst_output_indexes_iter;
    double this_log_beta = -fst.Final(state).Value();
    double this_log_alpha = log_alpha_data[state];
    for (fst::ArcIterator aiter(fst, state); !aiter.Done();
         aiter.Next(), this_fst_output_indexes_iter++) {
      const fst::StdArc &arc = aiter.Value();
      double next_log_beta = log_beta_data[arc.nextstate];
      BaseFloat transition_logprob = -arc.weight.Value();
      int32 index = *this_fst_output_indexes_iter;
      BaseFloat pseudo_loglike = nnet_logprob_data[index];
	  //后向计算公式
      this_log_beta = LogAdd(this_log_beta, pseudo_loglike +
                             transition_logprob + next_log_beta);
	  //每条边在当前时间出现的比例
      BaseFloat occupation_logprob = this_log_alpha + pseudo_loglike +
          transition_logprob + next_log_beta - tot_log_prob,
		  //转化为概率
          occupation_prob = exp(occupation_logprob);
      nnet_logprob_deriv_data[index] += occupation_prob;
    }
    // check for -inf.
    KALDI_PARANOID_ASSERT(this_log_beta - this_log_beta == 0);
    log_beta_data[state] = this_log_beta;
  }
  KALDI_ASSERT(fst_output_indexes_iter == &(fst_output_indexes_[0]));

  int32 start_state = 0;  // the fact that the start state is numbered 0 is
                          // implied by other properties of the FST
                          // (epsilon-free-ness and topological sorting, and
                          // connectedness).
  double tot_log_prob_backward = log_beta_(start_state);
  if (!ApproxEqual(tot_log_prob_backward, tot_log_prob_))
    KALDI_WARN << "Disagreement in forward/backward log-probs: "
               << tot_log_prob_backward << " vs. " << tot_log_prob_;

  // copy this data to GPU.
  CuVector nnet_logprob_deriv_cuda;
  nnet_logprob_deriv_cuda.Swap(&nnet_logprob_derivs_);
  //返回每条边上pdf对应出现的概率
  nnet_output_deriv->AddElements(supervision_.weight, nnet_output_indexes_,
                                 nnet_logprob_deriv_cuda.Data());
}


}  // namespace chain
}  // namespace kaldi

三、结论

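以上就是chain model中分子fst对应的前后向计算过程。前向计算（Forward）得到分子fst上的总对数似然，乘以监督权重后作为目标函数的分子项。后向计算（Backward）得到每个(帧, pdf-id)位置的占有概率（即各条边的后验概率之和），并乘以监督权重累加到神经网络输出的导数上。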

你可能感兴趣的:(c++,kaldi,ctc,asr)