hang: nnetbin/sat-nnet-train-frmshuff.cc注解2

循环主体(209-410行)

1.顺序读取特征,和相应的target

while(!feature_reader.Done){}
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);

typedef SequentialTableReader > > SequrentialBaseFloatMatrixReader;

template class SequentialTableReader (util/kaldi-table.h 中276行)
成员变量:SequentialTableReaderImplBase *impl_;
SequentialTableReaderScriptImpl和SequentialTableReaderArchiveImpl继承SequentialTableReaderImplBase(在util/kaldi-table-inl.h)

template class KaldiObjectHolder{ typedef KaldiType T;}

2.randomizer中添加数据AddData函数

  • Randomizer中所有成员变量都是Cu形式的,也就是都在显存中
  • 在AddData函数中data_begin_被置为0;逐渐添加数据的过程调整data_end_
  • (1)isFull函数的结束条件要求data_end_超过randomizer_size,kaldi会读入完整的句子,所以实际大小可以略微超出randomizer_size
    (2)data_begin_在逐步从randomizer中读取数据后会增加,判断data_begin_!=0且data_end_>size是读取完minibatch的情况。当重新添加时data_begin_会被置为0
     IsFull() { return ((data_begin_ == 0) && (data_end_ > conf_.randomizer_size )); }
    
  • 最终一轮数据添加结束的时:data_中为具体数据;data_begin_=0;data_end_为结束位置。
  • 循环添加直至等于或刚超过randomizer_size

3.添加randomizer的顺序(209-310行)

  • 判断是否randomizer.IsFull()
  • utt=feature_reader.Key()
  • num_no_tgt_mat记录无target的量;num_other_error记录无frame_weights、keep_frames的量
  • 获取feature和target pair(weights默认为1.0)
    Matrix mat = feature_reader.Value();
    Posterior targets = targets_reader.Value(utt);
    weights.Resize(mat.NumRows()).Set(1.0);
    
  • 可能会处理某些长度的mismatch
  • 如果有拼帧等处理或者特征变换,利用的是nnet_transf.
    nnet_transf.Feedforward(CuMatrix(mat),&feats_transf);
    
  • 获取相应的spkid
    std::vector spkid;
    if (utt2spk != "") {
      if (map_utt2spk.find(utt) != map_utt2spk.end()) {
        spkid.resize(feats_transf.NumRows(), map_utt2spk[utt]);
      } else {
        KALDI_WARN << utt << ", spkid is unknown";
        continue;
      }
    } else {
      spkid.resize(feats_transf.NumRows(), 0);
    }
    
  • 最后向randomizer中加入数据,准备进行混合,这样算添加完1句,num_done用于计数,每5000句会报告一次速度
    KALDI_ASSERT(feats_transf.NumRows() == targets.size());
    feature_randomizer.AddData(feats_transf);
    targets_randomizer.AddData(targets);
    weights_randomizer.AddData(weights);
    spkids_randomizer.AddData(spkid);
    num_done++;
    

你可能感兴趣的:(hang: nnetbin/sat-nnet-train-frmshuff.cc注解2)