Kaldi中解码代码解析

解码就是输入音频,利用声学模型、构建好的WFST解码网络,输出最优状态序列的过程。以Kaldi中LatticeFasterOnlineDecoder为例,解析解码代码。
示例程序:
online2-wav-nnet3-latgen-faster --do-endpointing=false --online=false --frame-subsampling-factor=3
--config=conf/online.conf --max-active=7000 --beam=15.0 --frames-per-chunk=50 --lattice-beam=6.0
--acoustic-scale=1.0 --word-symbol-table=words.txt final.mdl HCLG.fst ark:spk2utt.txt scp:test.scp ark,t:lat.debug.txt
声学模型:final.mdl Kaldi Chain model 文件解析
WFST:HCLG.fst
spk2utt.txt 内容如下:
wav10 wav10
wav9 wav9
test.scp 内容如下:
wav10 data/wav/00030/2017_03_07_16.57.22_1175.wav
wav9 data/wav/00030/2017_03_07_16.57.40_2562.wav

主要数据结构:

  1. Token
struct Token {
    BaseFloat tot_cost;     // 到该状态的累计最优cost
    BaseFloat extra_cost;   //token所有ForwardLinks中和最优路径的cost差的最小值,PruneActiveTokens 用到
    ForwardLink *links;     // 链表,表示现在时刻到下一时刻的那条跳转边
    Token *next;            // 指向同一时刻的下一个token
    Token *backpointer;     // 指向上一时刻的最佳token,相当于一个回溯指针                            
};
  1. ForwardLink
struct ForwardLink {
    Token *next_tok;    // 这条链接指向的token
    Label ilabel;       // 这下面的四个量取自解码图中的跳转/弧/边,因为每一个状态
    Label olabel;       // 维护一个token,那么token到token之间的连接信息和状态到状态之间的信息
    BaseFloat graph_cost;       // 应该保持一致,所以会有输入(tid),输出,权值(就是graph_cost)
    BaseFloat acoustic_cost;    // acoustic_cost就是tid对应的pdf_id的在声学模型中的后验
    ForwardLink *next;          // 链表结构,指向下一个
};
  1. TokenList
struct TokenList {
  Token *toks;        // 同一时刻的token链表头
  bool must_prune_forward_links;  // 这两个是Lattice剪枝标记,起始默认设置为true
  bool must_prune_tokens;
};
  1. HashList
template class HashList {
   struct Elem {
    I key;  // state
    T val;  // Token
    Elem *tail;
  };
    struct HashBucket {
    size_t prev_bucket;  // 指向下一个桶,最后一个指向-1
    Elem *last_elem;  // 指向挂在桶上的最后一个元素,空桶指向NULL
  };

  Elem *list_head_;  // 链表头
  size_t bucket_list_tail_;  // 当前活跃桶最后一个下标
  size_t hash_size_;  // 当前活跃桶个数
  std::vector buckets_;  //存储实际活跃的桶
  Elem *freed_head_;  // head of list of currently freed elements. [ready for allocation]
  std::vector allocated_;  // list of allocated blocks.
};
HashList

解码过程中上述数据结构对应的一些重要变量如下(来自decoder/lattice-faster-online-decoder.h)

  HashList toks_;

  std::vector active_toks_; // 每一帧对应其中一个TokenList,等于frame+1,
  std::vector queue_;  // 临时变量,用于ProcessNonemitting,保存的是下一时刻state
  std::vector tmp_array_;  // used in GetCutoff.
 
解码主要数据结构

解码整体流程:

  1. 模型、文件加载,配置生成;
  2. 三层循环
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {   //循环speaker
    ...
    const std::vector &uttlist = spk2utt_reader.Value();
    for (size_t i = 0; i < uttlist.size(); i++) {  //循环某个speaker的所有wav
        SingleUtteranceNnet3Decoder decoder(decoder_opts, trans_model, decodable_info,  *decode_fst, &feature_pipeline);       //构造函数中调用InitDecoding()                              
                                            
        //循环某个wav的chunk,比如说一帧一帧,online=false的时候一次加载整个wav
        while (samp_offset < data.Dim()) {  
            decoder.AdvanceDecoding();
        }
        decoder.FinalizeDecoding();

        decoder.GetLattice(end_of_utterance, &clat);
        GetDiagnosticsAndPrintOutput(utt, word_syms, clat,&num_frames, &tot_like);
                                     
    }  

}

对于单个wav,最主要流程就是三个函数:
void InitDecoding();
void LatticeFasterOnlineDecoder::AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames=-1)
void FinalizeDecoding();

其中AdvanceDecoding主流程如下图,每帧数据处理流程包括:

  • BaseFloat ProcessEmittingWrapper(DecodableInterface *decodable);
    实际调用LatticeFasterOnlineDecoder::ProcessEmitting>(decodable);
    处理输入非空跳转(ilabel != 0),主体两层循环,外层循环现在时刻所有Token,内层循环每个现在时刻的state能够跳转的下一时刻所有state。
    ProcessEmitting 函数中vector active_toks_ 加1(active_toks_.resize(active_toks_.size() + 1);),另外,NumFramesDecoded() 返回值等于active_toks_.size() - 1。

  • void ProcessNonemittingWrapper(BaseFloat cost_cutoff);
    实际调用LatticeFasterOnlineDecoder::ProcessNonemitting>(cost_cutoff);
    处理输入空跳转(ilabel == 0),主体两层循环,外层循环下一时刻所有Token,内层循环每个下一时刻的state能够跳转到的的state。可以这样理解,下一时刻的空跳转还是现在时刻通过一帧能够到达的时刻。

  • void PruneActiveTokens(BaseFloat delta);
    lattice beam 剪枝,默认25帧一次,包括两部分:剪枝ForwardLinks(PruneForwardLinks函数),剪枝Tokens(PruneTokensForFrame函数)

AdvanceDecoding 主流程
  1. 打印统计信息

主要函数解析:

  1. ProcessEmitting (decoder/lattice-faster-online-decoder.cc)
template 
BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting(
    DecodableInterface *decodable) {
  KALDI_ASSERT(active_toks_.size() > 0);
  int32 frame = active_toks_.size() - 1; 
  active_toks_.resize(active_toks_.size() + 1); //每帧+1,外层调用的while循环也是

  Elem *final_toks = toks_.Clear(); // 此处clear的是bucket,返回链表头,遍历可得现在时刻所有state的链表
  
  Elem *best_elem = NULL;
  BaseFloat adaptive_beam;
  size_t tok_cnt;
  // Beam prune 参数获取,包括cur_cutoff,adaptive_beam, best_elem。 后两者用来确定next_cutoff
  // 主要是两个条件,默认是best_weight + config_.beam,同时用config_.max_active、config_.min_active 做了加强,希望state数目在[config_.min_active, config_.max_active]之间
  BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
  PossiblyResizeHash(tok_cnt);  // This makes sure the hash is always big enough.

  BaseFloat next_cutoff = std::numeric_limits::infinity();
  
  BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
  
  const FstType &fst = dynamic_cast(fst_);

 // 下面这个块只是为了得到next_cutoff and cost_offset.
 // next_cutoff 用于下一时刻state的beam prune。等于现在时刻最优state到下一时刻对应所有state中最优的tot_cost
 // cost_offset 只是为了计算方面的考虑,相当于同时减了一个最小数。
  if (best_elem) {
    StateId state = best_elem->key;
    Token *tok = best_elem->val;
    cost_offset = - tok->tot_cost;
    for (fst::ArcIterator aiter(fst, state);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0) {  // propagate..
        BaseFloat new_weight = arc.weight.Value() + cost_offset -
            decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; // 这一步cost_offset + tok_tot_cost === 0,可以不要
        if (new_weight + adaptive_beam < next_cutoff)
          next_cutoff = new_weight + adaptive_beam;
      }
    }
  }
  ...

  // the tokens are now owned here, in final_toks, and the hash is empty.
  // 'owned' is a complex thing here; the point is we need to call DeleteElem
  // on each elem 'e' to let toks_ know we're done with them.
  for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { //外层循环,遍历现在时刻state
    // loop this way because we delete "e" as we go.
    StateId state = e->key;
    Token *tok = e->val;
    if (tok->tot_cost <= cur_cutoff) {  // 现在时刻beam prune,tot_cost控制在cur_cutoff阈值以内,cur_cutoff=现在时刻最优state tot_cost+beam
      for (fst::ArcIterator aiter(fst, state);  // 内层循环,遍历现在时刻某个state的所有跳转
           !aiter.Done();
           aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (arc.ilabel != 0) {  // 输入非空跳转
          BaseFloat ac_cost = cost_offset -
              decodable->LogLikelihood(frame, arc.ilabel),
              graph_cost = arc.weight.Value(),
              cur_cost = tok->tot_cost,
              tot_cost = cur_cost + ac_cost + graph_cost;
          if (tot_cost > next_cutoff) continue;  
          // 下一时刻beam prune,下一时刻tot_cost控制在阈值next_cutoff之内。
          // next_cutoff,初始值为:现在时刻最优state到下一时刻所有state中最优cost+adaptive_beam。注意不是下一时刻所有state中最优cost+adaptive_beam,后面再动态调整。
          else if (tot_cost + adaptive_beam < next_cutoff)
            next_cutoff = tot_cost + adaptive_beam; 
          
          //扩展下一时刻token,存取在toks_中,这一帧的ProcessNonemitting就是在toks_对应的list中循环。所以说ProcessNonemitting循环的是下一时刻的state以及下一时刻state的扩展跳转。
          Token *next_tok = FindOrAddToken(arc.nextstate,frame + 1, tot_cost, tok, NULL); 
          

          // 加边。Add ForwardLink from tok to next_tok (put on head of list tok->links)
          tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
                                       graph_cost, ac_cost, tok->links);
        }
      } // for all arcs
    }
    e_tail = e->tail;
    toks_.Delete(e); // delete Elem
  }
  return next_cutoff;
}

主体流程是双层循环,也就是Viterbi解码,外层循环现在时刻所有state,内层循环每个state对应的每个跳转,确定下一时刻所有state。过程中生成state对应的Token以及ForwardLink。同时用到了Beam Prune,现在时刻和下一时刻都有应用。

  1. ProcessNonemitting(BaseFloat cutoff) (decoder/lattice-faster-online-decoder.cc)
    首先遍历前面ProcessEmitting函数生成的HashList,得到现在时刻state 队列 queue_
    然后两层遍历:外层遍历queue_,内层遍历stata的空跳转;
    注意一点的是:frame = static_cast(active_toks_.size()) - 2 ,这个如果不注意,理解内循环中的FindOrAddToken函数会出现偏差。

  2. FindOrAddToken
    构造Token,插入到active_toks_[frame_plus_one].toks指向的Token list中,插入到HashList toks_中

inline LatticeFasterOnlineDecoder::Token *LatticeFasterOnlineDecoder::FindOrAddToken(
    StateId state, int32 frame_plus_one, BaseFloat tot_cost,
    Token *backpointer, bool *changed) {
  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true
  // if the token was newly created or the cost changed.
  KALDI_ASSERT(frame_plus_one < active_toks_.size());
  Token *&toks = active_toks_[frame_plus_one].toks; // 引用,注意后面的改变其实改变了右边的值
  Elem *e_found = toks_.Find(state);  //HashList中查找
  if (e_found == NULL) {  // no such token presently.
    const BaseFloat extra_cost = 0.0;
    Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); //构造Token,头插
    toks = new_tok;
    num_toks_++;
    toks_.Insert(state, new_tok); //toks_是一个HashList,ProcessNonemitting函数或者下一帧会用到
    if (changed) *changed = true;
    return new_tok;
  } else {
    Token *tok = e_found->val;  // There is an existing Token for this state.
    if (tok->tot_cost > tot_cost) {  // replace old token
      tok->tot_cost = tot_cost;
      tok->backpointer = backpointer;     
      if (changed) *changed = true;
    } else {
      if (changed) *changed = false;
    }
    return tok;
  }
}
  1. GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem)

Viterbi解码中涉及到现在时刻state数目以及下一时刻state数目,如果我们想要提高解码速度,需要对这两个数值都做缩减。实际做法是设置阈值,减少语音识别中现在时刻以及下一时刻状态数目,具体做法是:** 首先求现在时刻最优路径得分,加上beam,得到现在时刻得分阈值;然后求下一时刻最优路径得分,加上beam,得到下一时刻得分阈值**;具体步骤是:

  • 对所有状态排序,最优状态放最前面,最优状态得分=best_weight
  • 设置一个beam,设置阈值1=cur_cutoff,cur_cutoff=best_weight+beam,所有得分在cur_cutoff以内的,保留,反之丢弃,现在时刻的state数目减少。
  • 计算到下一时刻的最优路径得分new_weight。
  • 设置一个adaptive_beam, 设置阈值2=next_cutoff,next_cutoff=new_weight+adaptive_beam,所有得分在next_cutoff以内的,保留,反之丢弃,下一时刻的state数目减少。

注意上述步骤中的beam不是参数传递进去的config_.beam;因为我们如果直接用config_.beam,有可能卡出的state数目太多(大于config_.max_active)或者太少(少于config_.min_active)。所以需要分类讨论,确定最终的beam值,adaptive_beam类似。
cur_cutoff,adaptive_beam 都是来自GetCutoff函数:

// BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
// 输入final_toks,HashList对应的list,toks_.Clear() 操作后的得头结点指向
// 输出 cur_cutoff,返回值,用于现在时刻Beam Prune
// 输出 adaptive_beam, best_elem  得到next_cutoff,用于下一时刻Beam Prune
// 输出 tok_cnt  用于重置HashList toks_大小 ,足够大,减少内存分配时间

PossiblyResizeHash(tok_cnt)
BaseFloat LatticeFasterOnlineDecoder::GetCutoff(Elem *list_head, size_t *tok_count,
                                                  BaseFloat *adaptive_beam, Elem **best_elem) 
  1. PruneActiveTokens
    从后向前,主要做两步操作:
    PruneForwardLinks,删减Token的ForwordLinks,
    PruneTokensForFrame,删减Token本身,如果该Token对应的所有的ForwardLinks 都没有了,那Token本身也可以删除,判断条件tok->extra_cost == std::numeric_limits::infinity(),extra_cost代表该tok所有ForwardLinks到的next state 的tot_cost和到达该next state最优路径的tot_cost差的最小值,如果是无穷大(最小值都是无穷大)代表所有ForwordLinks都删除了。

Reference

http://www.funcwj.cn/2017/08/02/kaldi-online-decoder/
https://blog.csdn.net/u013677156/article/details/78930532

你可能感兴趣的:(Kaldi中解码代码解析)