Code Analysis of MFCC Feature Extraction in Kaldi

This article is reposted from the WeChat official account 433的3号同学.

The make_mfcc.sh script
Start with the top-level script make_mfcc.sh (https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/make_mfcc.sh). Its usage is:


./steps/make_mfcc.sh 
Usage: ./steps/make_mfcc.sh [options] <data-dir> [<log-dir> [<mfcc-dir>] ]
 e.g.: ./steps/make_mfcc.sh data/train
Note: <log-dir> defaults to <data-dir>/log, and
      <mfcc-dir> defaults to <data-dir>/data.
Options:
  --mfcc-config <config-file>          # config passed to compute-mfcc-feats.
  --nj <nj>                            # number of parallel jobs.
  --cmd <run.pl|queue.pl <queue opts>> # how to run jobs.
  --write-utt2num-frames <true|false>  # If true, write utt2num_frames file.
  --write-utt2dur <true|false>         # If true, write utt2dur file.


 $cmd JOB=1:$nj $logdir/make_mfcc_${name}.JOB.log \
    compute-mfcc-feats $vtln_opts $write_utt2dur_opt --verbose=2 \
      --config=$mfcc_config scp,p:$logdir/wav_${name}.JOB.scp ark:- \| \
    copy-feats $write_num_frames_opt --compress=$compress ark:- \
      ark,scp:$mfccdir/raw_mfcc_$name.JOB.ark,$mfccdir/raw_mfcc_$name.JOB.scp \
      || exit 1;

The excerpt above is the core command of the script. It runs compute-mfcc-feats, whose usage and options are:


./compute-mfcc-feats 

Create MFCC feature files.
Usage:  compute-mfcc-feats [options...] <wav-rspecifier> <feats-wspecifier>

Options:
  --allow-downsample          : If true, allow the input waveform to have a higher frequency than the specified --sample-frequency (and we'll downsample). (bool, default = false)
  --allow-upsample            : If true, allow the input waveform to have a lower frequency than the specified --sample-frequency (and we'll upsample). (bool, default = false)
  --blackman-coeff            : Constant coefficient for generalized Blackman window. (float, default = 0.42)
  --cepstral-lifter           : Constant that controls scaling of MFCCs (float, default = 22)
  --channel                   : Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (int, default = -1)
  --debug-mel                 : Print out debugging information for mel bin computation (bool, default = false)
  --dither                    : Dithering constant (0.0 means no dither). If you turn this off, you should set the --energy-floor option, e.g. to 1.0 or 0.1 (float, default = 1)
  --energy-floor              : Floor on energy (absolute, not relative) in MFCC computation. Only makes a difference if --use-energy=true; only necessary if --dither=0.0.  Suggested values: 0.1 or 1.0 (float, default = 0)
  --frame-length              : Frame length in milliseconds (float, default = 25)
  --frame-shift               : Frame shift in milliseconds (float, default = 10)
  --high-freq                 : High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (float, default = 0)
  --htk-compat                : If true, put energy or C0 last and use a factor of sqrt(2) on C0.  Warning: not sufficient to get HTK compatible features (need to change other parameters). (bool, default = false)
  --low-freq                  : Low cutoff frequency for mel bins (float, default = 20)
  --max-feature-vectors       : Memory optimization. If larger than 0, periodically remove feature vectors so that only this number of the latest feature vectors is retained. (int, default = -1)
  --min-duration              : Minimum duration of segments to process (in seconds). (float, default = 0)
  --num-ceps                  : Number of cepstra in MFCC computation (including C0) (int, default = 13)
  --num-mel-bins              : Number of triangular mel-frequency bins (int, default = 23)
  --output-format             : Format of the output files [kaldi, htk] (string, default = "kaldi")
  --preemphasis-coefficient   : Coefficient for use in signal preemphasis (float, default = 0.97)
  --raw-energy                : If true, compute energy before preemphasis and windowing (bool, default = true)
  --remove-dc-offset          : Subtract mean from waveform on each frame (bool, default = true)
  --round-to-power-of-two     : If true, round window size to power of two by zero-padding input to FFT. (bool, default = true)
  --sample-frequency          : Waveform data sample frequency (must match the waveform file, if specified there) (float, default = 16000)
  --snip-edges                : If true, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame-length.  If false, the number of frames depends only on the frame-shift, and we reflect the data at the ends. (bool, default = true)
  --subtract-mean             : Subtract mean of each feature file [CMS]; not recommended to do it this way.  (bool, default = false)
  --use-energy                : Use energy (not C0) in MFCC computation (bool, default = true)
  --utt2spk                   : Utterance to speaker-id map rspecifier (if doing VTLN and you have warps per speaker) (string, default = "")
  --vtln-high                 : High inflection point in piecewise linear VTLN warping function (if negative, offset from high-mel-freq (float, default = -500)
  --vtln-low                  : Low inflection point in piecewise linear VTLN warping function (float, default = 100)
  --vtln-map                  : Map from utterance or speaker-id to vtln warp factor (rspecifier) (string, default = "")
  --vtln-warp                 : Vtln warp factor (only applicable if vtln-map not specified) (float, default = 1)
  --window-type               : Type of window ("hamming"|"hanning"|"povey"|"rectangular"|"blackmann") (string, default = "povey")
  --write-utt2dur             : Wspecifier to write duration of each utterance in seconds, e.g. 'ark,t:utt2dur'. (string, default = "")

Standard options:
  --config                    : Configuration file to read (this option may be repeated) (string, default = "")
  --help                      : Print out usage message (bool, default = false)
  --print-args                : Print the command line arguments (to stderr) (bool, default = true)
  --verbose                   : Verbose level (higher->more logging) (int, default = 0)


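There are many options, but only two arguments are required: the input wav-rspecifier and the output feats-wspecifier. Readers unfamiliar with rspecifiers and wspecifiers should first read "Kaldi文档解读" (a walkthrough of the Kaldi documentation) and the "Kaldi I/O mechanisms" page. The first step is to prepare the input file, which lists the WAV data for each utterance.
compute-mfcc-feats can only read WAV data, so audio in any other format has to be converted to WAV first. The conversion can be done offline, ahead of time, with an external tool.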

compute-mfcc-feats.cc


 int main(int argc, char *argv[]) {
   try {
     using namespace kaldi;
     const char *usage =
         "Create MFCC feature files.\n"
         "Usage:  compute-mfcc-feats [options...] <wav-rspecifier> "
         "<feats-wspecifier>\n";
 
     // Construct all the global objects.
     ParseOptions po(usage);
     MfccOptions mfcc_opts;
     // Define defaults for global options.
     bool subtract_mean = false;
     BaseFloat vtln_warp = 1.0;
     std::string vtln_map_rspecifier;
     std::string utt2spk_rspecifier;
     int32 channel = -1;
     BaseFloat min_duration = 0.0;
     std::string output_format = "kaldi";
     std::string utt2dur_wspecifier;
 
     // Register the MFCC option struct.
     mfcc_opts.Register(&po);
 
     // Register the options.
     po.Register("output-format", &output_format, "Format of the output "
                 "files [kaldi, htk]");
     po.Register("subtract-mean", &subtract_mean, "Subtract mean of each "
                 "feature file [CMS]; not recommended to do it this way. ");
     po.Register("vtln-warp", &vtln_warp, "Vtln warp factor (only applicable "
                 "if vtln-map not specified)");
     po.Register("vtln-map", &vtln_map_rspecifier, "Map from utterance or "
                 "speaker-id to vtln warp factor (rspecifier)");
     po.Register("utt2spk", &utt2spk_rspecifier, "Utterance to speaker-id map "
                 "rspecifier (if doing VTLN and you have warps per speaker)");
     po.Register("channel", &channel, "Channel to extract (-1 -> expect mono, "
                 "0 -> left, 1 -> right)");
     po.Register("min-duration", &min_duration, "Minimum duration of segments "
                 "to process (in seconds).");
     po.Register("write-utt2dur", &utt2dur_wspecifier, "Wspecifier to write "
                 "duration of each utterance in seconds, e.g. 'ark,t:utt2dur'.");
 
     po.Read(argc, argv);
 
     if (po.NumArgs() != 2) {
       po.PrintUsage();
       exit(1);
     }
 
     std::string wav_rspecifier = po.GetArg(1);
 
     std::string output_wspecifier = po.GetArg(2);
 
     Mfcc mfcc(mfcc_opts);
 
     if (utt2spk_rspecifier != "" && vtln_map_rspecifier == "")
       KALDI_ERR << ("The --utt2spk option is only needed if "
                     "the --vtln-map option is used.");
     RandomAccessBaseFloatReaderMapped vtln_map_reader(vtln_map_rspecifier,
                                                       utt2spk_rspecifier);
 
     SequentialTableReader<WaveHolder> reader(wav_rspecifier);
     BaseFloatMatrixWriter kaldi_writer;  // typedef to TableWriter<something>.
     TableWriter<HtkMatrixHolder> htk_writer;
 
     if (output_format == "kaldi") {
       if (!kaldi_writer.Open(output_wspecifier))
         KALDI_ERR << "Could not initialize output with wspecifier "
                   << output_wspecifier;
     } else if (output_format == "htk") {
       if (!htk_writer.Open(output_wspecifier))
         KALDI_ERR << "Could not initialize output with wspecifier "
                   << output_wspecifier;
     } else {
       KALDI_ERR << "Invalid output_format string " << output_format;
     }
 
     DoubleWriter utt2dur_writer(utt2dur_wspecifier);
 
     int32 num_utts = 0, num_success = 0;
     for (; !reader.Done(); reader.Next()) {
       num_utts++;
       std::string utt = reader.Key();
       const WaveData &wave_data = reader.Value();
       if (wave_data.Duration() < min_duration) {
         KALDI_WARN << "File: " << utt << " is too short ("
                    << wave_data.Duration() << " sec): producing no output.";
         continue;
       }
       int32 num_chan = wave_data.Data().NumRows(), this_chan = channel;
       {  // This block works out the channel (0=left, 1=right...)
         KALDI_ASSERT(num_chan > 0);  // should have been caught in
         // reading code if no channels.
         if (channel == -1) {
           this_chan = 0;
           if (num_chan != 1)
             KALDI_WARN << "Channel not specified but you have data with "
                        << num_chan  << " channels; defaulting to zero";
         } else {
           if (this_chan >= num_chan) {
             KALDI_WARN << "File with id " << utt << " has "
                        << num_chan << " channels but you specified channel "
                        << channel << ", producing no output.";
             continue;
           }
         }
       }
       BaseFloat vtln_warp_local;  // Work out VTLN warp factor.
       if (vtln_map_rspecifier != "") {
         if (!vtln_map_reader.HasKey(utt)) {
           KALDI_WARN << "No vtln-map entry for utterance-id (or speaker-id) "
                      << utt;
           continue;
         }
         vtln_warp_local = vtln_map_reader.Value(utt);
       } else {
         vtln_warp_local = vtln_warp;
       }
 
       SubVector<BaseFloat> waveform(wave_data.Data(), this_chan);
       Matrix<BaseFloat> features;
       try {
         mfcc.ComputeFeatures(waveform, wave_data.SampFreq(),
                              vtln_warp_local, &features);
       } catch (...) {
         KALDI_WARN << "Failed to compute features for utterance " << utt;
         continue;
       }
       if (subtract_mean) {
         Vector<BaseFloat> mean(features.NumCols());
         mean.AddRowSumMat(1.0, features);
         mean.Scale(1.0 / features.NumRows());
         for (int32 i = 0; i < features.NumRows(); i++)
           features.Row(i).AddVec(-1.0, mean);
       }
       if (output_format == "kaldi") {
         kaldi_writer.Write(utt, features);
       } else {
         std::pair<Matrix<BaseFloat>, HtkHeader> p;
         p.first.Resize(features.NumRows(), features.NumCols());
         p.first.CopyFromMat(features);
         HtkHeader header = {
           features.NumRows(),
           100000,  // 10ms shift
           static_cast<int16>(sizeof(float)*(features.NumCols())),
           static_cast<uint16>( 006 | // MFCC
           (mfcc_opts.use_energy ? 0100 : 020000)) // energy; otherwise c0
         };
         p.second = header;
         htk_writer.Write(utt, p);
       }
       if (utt2dur_writer.IsOpen()) {
         utt2dur_writer.Write(utt, wave_data.Duration());
       }
       if (num_utts % 10 == 0)
         KALDI_LOG << "Processed " << num_utts << " utterances";
       KALDI_VLOG(2) << "Processed features for key " << utt;
       num_success++;
     }
     KALDI_LOG << " Done " << num_success << " out of " << num_utts
               << " utterances.";
     return (num_success != 0 ? 0 : 1);
   } catch(const std::exception &e) {
     std::cerr << e.what();
     return -1;
   }
 }

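We can skip the option parsing at the top. The line that matters is "Mfcc mfcc(mfcc_opts);", which constructs the object that does the real work. Mfcc is a typedef of the OfflineFeatureTpl class template instantiated with MfccComputer:

typedef OfflineFeatureTpl<MfccComputer> Mfcc;

The Kaldi documentation describes OfflineFeatureTpl as follows:
This templated class is intended for offline feature extraction, i.e. where you have access to the entire signal at the start. It exists mainly to be a drop-in replacement for the old (pre-2016) classes Mfcc, Plp and so on, for use in the offline case. In April 2016 the online feature-computation code was reorganized for greater modularity and to correctly support the snip-edges=false option.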

ComputeFeatures
Now look at the function that computes the features:


 template <class F>
 void OfflineFeatureTpl<F>::ComputeFeatures(
     const VectorBase<BaseFloat> &wave,
     BaseFloat sample_freq,
     BaseFloat vtln_warp,
     Matrix<BaseFloat> *output) {
   KALDI_ASSERT(output != NULL);
   BaseFloat new_sample_freq = computer_.GetFrameOptions().samp_freq;
   if (sample_freq == new_sample_freq) {
     Compute(wave, vtln_warp, output);
   } else {
     if (new_sample_freq < sample_freq &&
         ! computer_.GetFrameOptions().allow_downsample)
         KALDI_ERR << "Waveform and config sample Frequency mismatch: "
                   << sample_freq << " .vs " << new_sample_freq
                   << " (use --allow-downsample=true to allow "
                   << " downsampling the waveform).";
     else if (new_sample_freq > sample_freq &&
              ! computer_.GetFrameOptions().allow_upsample)
       KALDI_ERR << "Waveform and config sample Frequency mismatch: "
                   << sample_freq << " .vs " << new_sample_freq
                 << " (use --allow-upsample=true option to allow "
                 << " upsampling the waveform).";
     // Resample the waveform.
     Vector<BaseFloat> resampled_wave(wave);
     ResampleWaveform(sample_freq, wave,
                      new_sample_freq, &resampled_wave);
     Compute(resampled_wave, vtln_warp, output);
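   }
 }

This function only checks whether the sample rate read from the WAV header matches the one configured for compute-mfcc-feats (--sample-frequency). If they match, it calls Compute directly. Otherwise, provided the corresponding --allow-downsample or --allow-upsample option is set, it downsamples or upsamples the waveform to the configured rate and then calls Compute; if the option is not set, it raises an error.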
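The branching described above can be isolated in a few lines. The following is a minimal sketch, not Kaldi code: the enum RateAction and the function DecideRateAction are hypothetical names introduced only for illustration, and the actual resampling step is left out.

```cpp
#include <cassert>

// Hypothetical names for illustration; this mirrors only the branching
// in ComputeFeatures, not the resampling itself.
enum class RateAction { kUseAsIs, kResample, kError };

// wav_freq: sample rate found in the WAV header.
// config_freq: sample rate configured via --sample-frequency.
RateAction DecideRateAction(float wav_freq, float config_freq,
                            bool allow_downsample, bool allow_upsample) {
  assert(wav_freq > 0.0f && config_freq > 0.0f);
  if (wav_freq == config_freq) return RateAction::kUseAsIs;
  if (config_freq < wav_freq) {
    // The file has a higher rate than configured: downsampling is needed.
    return allow_downsample ? RateAction::kResample : RateAction::kError;
  }
  // The file has a lower rate than configured: upsampling is needed.
  return allow_upsample ? RateAction::kResample : RateAction::kError;
}
```

For example, a 44.1 kHz file with a 16 kHz configuration yields kError unless allow_downsample is true, which corresponds to the KALDI_ERR branch in the excerpt.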

Compute


 template <class F>
 void OfflineFeatureTpl<F>::Compute(
     const VectorBase<BaseFloat> &wave,
     BaseFloat vtln_warp,
     Matrix<BaseFloat> *output) {
   KALDI_ASSERT(output != NULL);
   int32 rows_out = NumFrames(wave.Dim(), computer_.GetFrameOptions()),
       cols_out = computer_.Dim();
   if (rows_out == 0) {
     output->Resize(0, 0);
     return;
   }
   output->Resize(rows_out, cols_out);
   Vector<BaseFloat> window;  // windowed waveform.
   bool use_raw_log_energy = computer_.NeedRawLogEnergy();
   for (int32 r = 0; r < rows_out; r++) {  // r is frame index.
     BaseFloat raw_log_energy = 0.0;
     ExtractWindow(0, wave, r, computer_.GetFrameOptions(),
                   feature_window_function_, &window,
                   (use_raw_log_energy ? &raw_log_energy : NULL));
 
     SubVector<BaseFloat> output_row(*output, r);
     computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row);
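   }
 }

Compute first uses NumFrames to work out how many frames the waveform contains, then loops over the frames: ExtractWindow extracts each frame, and computer_.Compute turns it into a feature vector.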

NumFrames
The default mode is snip_edges=true, which matches HTK: no padding is ever needed, because a frame that would run past the end of the signal is simply dropped.


 int32 NumFrames(int64 num_samples,
                 const FrameExtractionOptions &opts,
                 bool flush) {
   int64 frame_shift = opts.WindowShift();
   int64 frame_length = opts.WindowSize();
   if (opts.snip_edges) {
     // with --snip-edges=true (the default), we use a HTK-like approach to
     // determining the number of frames-- all frames have to fit completely into
     // the waveform, and the first frame begins at sample zero.
     if (num_samples < frame_length)
       return 0;
     else
       return (1 + ((num_samples - frame_length) / frame_shift));
     // You can understand the expression above as follows: 'num_samples -
     // frame_length' is how much room we have to shift the frame within the
     // waveform; 'frame_shift' is how much we shift it each time; and the ratio
     // is how many times we can shift it (integer arithmetic rounds down).
   } else {
     // if --snip-edges=false, the number of frames is determined by rounding the
     // (file-length / frame-shift) to the nearest integer.  The point of this
     // formula is to make the number of frames an obvious and predictable
     // function of the frame shift and signal length, which makes many
     // segmentation-related questions simpler.
     //
     // Because integer division in C++ rounds toward zero, we add (half the
     // frame-shift minus epsilon) before dividing, to have the effect of
     // rounding towards the closest integer.
     int32 num_frames = (num_samples + (frame_shift / 2)) / frame_shift;
 
     if (flush)
       return num_frames;
 
     // note: 'end' always means the last plus one, i.e. one past the last.
     int64 end_sample_of_last_frame = FirstSampleOfFrame(num_frames - 1, opts)
         + frame_length;
 
     // the following code is optimized more for clarity than efficiency.
     // If flush == false, we can't output frames that extend past the end
     // of the signal.
     while (num_frames > 0 && end_sample_of_last_frame > num_samples) {
       num_frames--;
       end_sample_of_last_frame -= frame_shift;
     }
     return num_frames;
   }
 }
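A worked example makes the two formulas concrete. The sketch below re-implements only the arithmetic, with hypothetical function names, and assumes the default 25 ms window and 10 ms shift at 16 kHz, i.e. frame_length = 400 and frame_shift = 160 samples.

```cpp
#include <cassert>
#include <cstdint>

// Frame count when every frame must fit entirely inside the signal
// (the snip_edges = true branch).
int64_t NumFramesSnipEdges(int64_t num_samples, int64_t frame_length,
                           int64_t frame_shift) {
  assert(frame_length > 0 && frame_shift > 0);
  if (num_samples < frame_length) return 0;
  return 1 + (num_samples - frame_length) / frame_shift;
}

// Frame count for snip_edges = false with flush = true:
// num_samples / frame_shift rounded to the nearest integer.
int64_t NumFramesNoSnipFlushed(int64_t num_samples, int64_t frame_shift) {
  assert(frame_shift > 0);
  return (num_samples + frame_shift / 2) / frame_shift;
}
```

For one second of audio (16000 samples), the first function gives 1 + (16000 - 400) / 160 = 1 + 97 = 98 frames, while the second gives (16000 + 80) / 160 = 100 frames. With snip_edges=false the count depends only on the frame shift, which is the predictability the source comment refers to.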

ExtractWindow
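Extracting a frame involves a few details: a sample offset (sample_offset); zero-padding at the end of the frame when its length is not a power of two; and, inside ProcessWindow, dithering (adding a small amount of random noise so that the log is never taken of zero), pre-emphasis, and computing the frame's energy. Finally, the samples are multiplied by the window function, FeatureWindowFunction (by default the "povey" window, designed by Povey himself).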
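ProcessWindow itself is not shown in this article, so here is a rough sketch of three of the per-frame steps it is described as performing: DC-offset removal, log-energy computation, and pre-emphasis. The function names and the use of std::vector are assumptions made for illustration; dithering and the window multiplication are omitted, and the treatment of the first sample in Preemphasize reflects my understanding of Kaldi's behaviour rather than something stated in the text.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Subtract the frame mean (what --remove-dc-offset=true asks for).
void RemoveDcOffset(std::vector<float> *frame) {
  assert(frame != nullptr && !frame->empty());
  double sum = 0.0;
  for (float x : *frame) sum += x;
  const float mean = static_cast<float>(sum / frame->size());
  for (float &x : *frame) x -= mean;
}

// Log of the frame energy, floored so the log of zero is never taken.
float LogEnergy(const std::vector<float> &frame) {
  double energy = 0.0;
  for (float x : frame) energy += static_cast<double>(x) * x;
  const double floor = std::numeric_limits<float>::epsilon();
  return static_cast<float>(std::log(std::max(energy, floor)));
}

// First-order pre-emphasis, y[i] = x[i] - coeff * x[i-1], done in place from
// the end so that x[i-1] is still the original value. The first sample has
// no predecessor, so it is treated as its own predecessor (assumption).
void Preemphasize(std::vector<float> *frame, float coeff) {
  assert(frame != nullptr && !frame->empty());
  assert(coeff >= 0.0f && coeff <= 1.0f);
  for (std::size_t i = frame->size() - 1; i > 0; --i)
    (*frame)[i] -= coeff * (*frame)[i - 1];
  (*frame)[0] -= coeff * (*frame)[0];
}
```

With the default coefficient 0.97, a constant frame is reduced to about 3% of its amplitude, which shows how pre-emphasis suppresses slowly varying (low-frequency) content.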


 void ExtractWindow(int64 sample_offset,
                    const VectorBase<BaseFloat> &wave,
                    int32 f,  // with 0 <= f < NumFrames(feats, opts)
                    const FrameExtractionOptions &opts,
                    const FeatureWindowFunction &window_function,
                    Vector<BaseFloat> *window,
                    BaseFloat *log_energy_pre_window) {
   KALDI_ASSERT(sample_offset >= 0 && wave.Dim() != 0);
   int32 frame_length = opts.WindowSize(),
       frame_length_padded = opts.PaddedWindowSize();
   int64 num_samples = sample_offset + wave.Dim(),
       start_sample = FirstSampleOfFrame(f, opts),
       end_sample = start_sample + frame_length;
 
   if (opts.snip_edges) {
     KALDI_ASSERT(start_sample >= sample_offset &&
                  end_sample <= num_samples);
   } else {
     KALDI_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
   }
 
   if (window->Dim() != frame_length_padded)
     window->Resize(frame_length_padded, kUndefined);
 
   // wave_start and wave_end are start and end indexes into 'wave', for the
   // piece of wave that we're trying to extract.
   int32 wave_start = int32(start_sample - sample_offset),
       wave_end = wave_start + frame_length;
   if (wave_start >= 0 && wave_end <= wave.Dim()) {
     // the normal case-- no edge effects to consider.
     window->Range(0, frame_length).CopyFromVec(
         wave.Range(wave_start, frame_length));
   } else {
     // Deal with any end effects by reflection, if needed.  This code will only
     // be reached for about two frames per utterance, so we don't concern
     // ourselves excessively with efficiency.
     int32 wave_dim = wave.Dim();
     for (int32 s = 0; s < frame_length; s++) {
       int32 s_in_wave = s + wave_start;
       while (s_in_wave < 0 || s_in_wave >= wave_dim) {
         // reflect around the beginning or end of the wave.
         // e.g. -1 -> 0, -2 -> 1.
         // dim -> dim - 1, dim + 1 -> dim - 2.
         // the code supports repeated reflections, although this
         // would only be needed in pathological cases.
         if (s_in_wave < 0) s_in_wave = - s_in_wave - 1;
         else s_in_wave = 2 * wave_dim - 1 - s_in_wave;
       }
       (*window)(s) = wave(s_in_wave);
     }
   }
 
   if (frame_length_padded > frame_length)
     window->Range(frame_length, frame_length_padded - frame_length).SetZero();
 
   SubVector<BaseFloat> frame(*window, 0, frame_length);
 
   ProcessWindow(opts, window_function, &frame, log_energy_pre_window);
 }
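The reflection rule used for edge frames (when snip_edges=false) is easy to check in isolation. The helper below, with the hypothetical name ReflectIndex, copies the index arithmetic from the loop in ExtractWindow.

```cpp
#include <cassert>

// Map a possibly out-of-range sample index into [0, dim) by mirroring it
// around the beginning or the end of the waveform, repeating if necessary.
int ReflectIndex(int s, int dim) {
  assert(dim > 0);
  while (s < 0 || s >= dim) {
    if (s < 0) s = -s - 1;          // -1 -> 0, -2 -> 1, ...
    else s = 2 * dim - 1 - s;       // dim -> dim - 1, dim + 1 -> dim - 2, ...
  }
  return s;
}
```

For a waveform of 10 samples, index -1 maps to 0 and index 11 maps to 8, so the signal is mirrored without repeating the boundary sample's neighbour out of order. An index far outside the range, such as -5 with dim = 3, needs two reflections and ends at 1.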
 

The window itself is built by the FeatureWindowFunction constructor:

 FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) {
   int32 frame_length = opts.WindowSize();
   KALDI_ASSERT(frame_length > 0);
   window.Resize(frame_length);
   double a = M_2PI / (frame_length-1);
   for (int32 i = 0; i < frame_length; i++) {
     double i_fl = static_cast<double>(i);
     if (opts.window_type == "hanning") {
       window(i) = 0.5  - 0.5*cos(a * i_fl);
     } else if (opts.window_type == "hamming") {
       window(i) = 0.54 - 0.46*cos(a * i_fl);
     } else if (opts.window_type == "povey") {  // like hamming but goes to zero at edges.
       window(i) = pow(0.5 - 0.5*cos(a * i_fl), 0.85);
     } else if (opts.window_type == "rectangular") {
       window(i) = 1.0;
     } else if (opts.window_type == "blackman") {
       window(i) = opts.blackman_coeff - 0.5*cos(a * i_fl) +
         (0.5 - opts.blackman_coeff) * cos(2 * a * i_fl);
     } else {
       KALDI_ERR << "Invalid window type " << opts.window_type;
     }
   }
 }
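To see how the "povey" window differs from Hamming, the two formulas from the constructor can be evaluated side by side. This is a standalone sketch with a hypothetical MakeWindow function; any type other than "hamming" or "povey" falls back to a rectangular window here, which is a simplification and not what Kaldi does (Kaldi raises an error).

```cpp
#include <cassert>
#include <cmath>
#include <string>
#include <vector>

// Evaluate the window formulas shown above for a given frame length.
std::vector<double> MakeWindow(const std::string &type, int frame_length) {
  assert(frame_length > 1);
  const double kTwoPi = 6.283185307179586;
  const double a = kTwoPi / (frame_length - 1);
  std::vector<double> w(frame_length);
  for (int i = 0; i < frame_length; ++i) {
    const double c = std::cos(a * i);
    if (type == "hamming") {
      w[i] = 0.54 - 0.46 * c;
    } else if (type == "povey") {
      w[i] = std::pow(0.5 - 0.5 * c, 0.85);  // Hann raised to the power 0.85
    } else {
      w[i] = 1.0;  // simplification: rectangular fallback
    }
  }
  return w;
}
```

At the first sample the Hamming window equals 0.54 - 0.46 = 0.08, whereas the povey window is exactly 0, which is what the source comment "like hamming but goes to zero at edges" means. Both reach 1 at the centre of the frame.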

MfccComputer

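We finally reach the code that does the actual work. The loop in Compute above calls MfccComputer's Compute function once per frame: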


 void MfccComputer::Compute(BaseFloat signal_raw_log_energy,
                            BaseFloat vtln_warp,
                            VectorBase<BaseFloat> *signal_frame,
                            VectorBase<BaseFloat> *feature) {
   KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() &&
                feature->Dim() == this->Dim());
   // Get the mel filter bank. For reuse, the banks are cached in a map keyed
   // by the VTLN warp factor.
   const MelBanks &mel_banks = *(GetMelBanks(vtln_warp));
 
   if (opts_.use_energy && !opts_.raw_energy)
     // Compute the energy as the dot product of the frame with itself.
     signal_raw_log_energy = Log(std::max<BaseFloat>(VecVec(*signal_frame, *signal_frame),
                                      std::numeric_limits<float>::epsilon()));
   // FFT; split-radix is the default algorithm.
   if (srfft_ != NULL)  // Compute FFT using the split-radix algorithm.
     srfft_->Compute(signal_frame->Data(), true);
   else  // An alternative algorithm that works for non-powers-of-two.
     RealFft(signal_frame, true);
   // Convert the FFT into a power spectrum (squared magnitude of each
   // complex FFT bin).
   ComputePowerSpectrum(signal_frame);
   SubVector<BaseFloat> power_spectrum(*signal_frame, 0,
                                       signal_frame->Dim() / 2 + 1);
   // Apply the mel filter bank to get the energy in each filter.
   mel_banks.Compute(power_spectrum, &mel_energies_);
   // Avoid log of zero. Dithering should already prevent it, but dithering
   // is optional, so floor the energies to be safe.
   mel_energies_.ApplyFloor(std::numeric_limits<float>::epsilon());
   mel_energies_.ApplyLog();  // take the log.
 
   feature->SetZero();  // in case there were NaNs.
   // Apply the DCT to get the cepstrum.
   // feature = dct_matrix_ * mel_energies [which now have log]
   feature->AddMatVec(1.0, dct_matrix_, kNoTrans, mel_energies_, 0.0);
   
   if (opts_.cepstral_lifter != 0.0)
     feature->MulElements(lifter_coeffs_);
   // If use_energy is set, replace the first cepstral coefficient (C0) with
   // the frame's log energy.
   if (opts_.use_energy) {
     if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_)
       signal_raw_log_energy = log_energy_floor_;
     (*feature)(0) = signal_raw_log_energy;
   }
 
   if (opts_.htk_compat) {
     BaseFloat energy = (*feature)(0);
     for (int32 i = 0; i < opts_.num_ceps - 1; i++)
       (*feature)(i) = (*feature)(i+1);
     if (!opts_.use_energy)
       energy *= M_SQRT2;  // scale on C0 (actually removing a scale
     // we previously added that's part of one common definition of
     // the cosine transform.)
     (*feature)(opts_.num_ceps - 1)  = energy;
   }
 }
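The last two numerical steps, the DCT and the cepstral liftering, can be sketched without any Kaldi types. The code below assumes an orthonormal DCT-II and the sinusoidal lifter 1 + (Q/2) sin(pi k / Q); I believe these match what dct_matrix_ and lifter_coeffs_ contain, but that is not shown in the excerpts above, so treat both formulas and the function names as assumptions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Orthonormal DCT-II of the log mel energies, keeping the first num_ceps
// coefficients (assumed to correspond to feature = dct_matrix_ * mel_energies).
std::vector<double> DctII(const std::vector<double> &log_mel, int num_ceps) {
  const int n = static_cast<int>(log_mel.size());
  assert(n > 0 && num_ceps > 0 && num_ceps <= n);
  const double kPi = 3.141592653589793;
  std::vector<double> ceps(num_ceps, 0.0);
  for (int k = 0; k < num_ceps; ++k) {
    const double scale = (k == 0) ? std::sqrt(1.0 / n) : std::sqrt(2.0 / n);
    for (int j = 0; j < n; ++j)
      ceps[k] += scale * std::cos(kPi / n * (j + 0.5) * k) * log_mel[j];
  }
  return ceps;
}

// Sinusoidal liftering with parameter q (22 by default via --cepstral-lifter).
// Coefficient 0 is left unchanged because sin(0) = 0.
void Lifter(std::vector<double> *ceps, double q) {
  assert(ceps != nullptr);
  if (q == 0.0) return;  // liftering disabled
  const double kPi = 3.141592653589793;
  for (std::size_t k = 0; k < ceps->size(); ++k)
    (*ceps)[k] *= 1.0 + 0.5 * q * std::sin(kPi * k / q);
}
```

A flat log-mel spectrum is a useful sanity check: with 23 bins all equal to 2.0, C0 comes out as 2 * sqrt(23) and every higher coefficient is zero up to rounding, since a flat spectrum has no cepstral structure.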

Finally, we can run a simple experiment with the compute-mfcc-feats program:
[Figures: two screenshots of the experiment from the original post]

compute_cmvn_stats.sh
Next, look at the top-level script compute_cmvn_stats.sh (https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/compute_cmvn_stats.sh). Its usage is:


steps/compute_cmvn_stats.sh 
Usage: steps/compute_cmvn_stats.sh [options] <data-dir> [<log-dir> [<cmvn-dir>] ]
e.g.: steps/compute_cmvn_stats.sh data/train exp/make_mfcc/train mfcc
Note: <log-dir> defaults to <data-dir>/log, and <cmvn-dir> defaults to <data-dir>/data
Options:
 --fake          gives you fake cmvn stats that do no normalization.
 --two-channel   is for two-channel telephone data, there must be no segments 
                 file and reco2file_and_channel must be present.  It will take
                 only frames that are louder than the other channel.
 --fake-dims <n1:n2>  Generate stats that won't cause normalization for these
                  dimensions (e.g. 13:14:15)


! compute-cmvn-stats --spk2utt=ark:$data/spk2utt scp:$data/feats.scp ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp \
    2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats. See $logdir/cmvn_$name.log" && exit 1;

The excerpt above is the core command of the script. It runs compute-cmvn-stats, whose usage and options are:


compute-cmvn-stats 

Compute cepstral mean and variance normalization statistics
If wspecifier provided: per-utterance by default, or per-speaker if
spk2utt option provided; if wxfilename: global
Usage: compute-cmvn-stats  [options] <feats-rspecifier> (<stats-wspecifier>|<stats-wxfilename>)
e.g.: compute-cmvn-stats --spk2utt=ark:data/train/spk2utt scp:data/train/feats.scp ark,scp:/foo/bar/cmvn.ark,data/train/cmvn.scp
See also: apply-cmvn, modify-cmvn-stats

Options:
  --binary                    : write in binary mode (applies only to global CMN/CVN) (bool, default = true)
  --spk2utt                   : rspecifier for speaker to utterance-list map (string, default = "")
  --weights                   : rspecifier for a vector of floats for each utterance, that's a per-frame weight. (string, default = "")

Standard options:
  --config                    : Configuration file to read (this option may be repeated) (string, default = "")
  --help                      : Print out usage message (bool, default = false)
  --print-args                : Print the command line arguments (to stderr) (bool, default = true)
  --verbose                   : Verbose level (higher->more logging) (int, default = 0)

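There are only a few options, and just two arguments are required: the input feats-rspecifier and the output stats-wspecifier (or stats-wxfilename). The input is the output of the previous step, i.e. the raw_mfcc feature files produced by MFCC extraction.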

compute-cmvn-stats.cc


 int main(int argc, char *argv[]) {
   try {
     using namespace kaldi;
     using kaldi::int32;
 
     const char *usage =
         "Compute cepstral mean and variance normalization statistics\n"
         "If wspecifier provided: per-utterance by default, or per-speaker if\n"
         "spk2utt option provided; if wxfilename: global\n"
         "Usage: compute-cmvn-stats  [options] <feats-rspecifier> (<stats-wspecifier>|<stats-wxfilename>)\n"
         "e.g.: compute-cmvn-stats --spk2utt=ark:data/train/spk2utt"
         " scp:data/train/feats.scp ark,scp:/foo/bar/cmvn.ark,data/train/cmvn.scp\n"
         "See also: apply-cmvn, modify-cmvn-stats\n";
 
     ParseOptions po(usage);
     std::string spk2utt_rspecifier, weights_rspecifier;
     bool binary = true;
     po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to utterance-list map");
     po.Register("binary", &binary, "write in binary mode (applies only to global CMN/CVN)");
     po.Register("weights", &weights_rspecifier, "rspecifier for a vector of floats "
                 "for each utterance, that's a per-frame weight.");
 
     po.Read(argc, argv);
 
     if (po.NumArgs() != 2) {
       po.PrintUsage();
       exit(1);
     }
 
     int32 num_done = 0, num_err = 0;
     std::string rspecifier = po.GetArg(1);
     std::string wspecifier_or_wxfilename = po.GetArg(2);
 
     RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier);
 
     if (ClassifyWspecifier(wspecifier_or_wxfilename, NULL, NULL, NULL)
         != kNoWspecifier) { // writing to a Table: per-speaker or per-utt CMN/CVN.
       std::string wspecifier = wspecifier_or_wxfilename;
 
       DoubleMatrixWriter writer(wspecifier);
 
       if (spk2utt_rspecifier != "") {
         SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
         RandomAccessBaseFloatMatrixReader feat_reader(rspecifier);
 
         for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
           std::string spk = spk2utt_reader.Key();
           const std::vector<std::string> &uttlist = spk2utt_reader.Value();
           bool is_init = false;
           Matrix<double> stats;
           for (size_t i = 0; i < uttlist.size(); i++) {
             std::string utt = uttlist[i];
             if (!feat_reader.HasKey(utt)) {
               KALDI_WARN << "Did not find features for utterance " << utt;
               num_err++;
               continue;
             }
             const Matrix<BaseFloat> &feats = feat_reader.Value(utt);
             if (!is_init) {
               InitCmvnStats(feats.NumCols(), &stats);
               is_init = true;
             }
             if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
               num_err++;
             } else {
               num_done++;
             }
           }
           if (stats.NumRows() == 0) {
             KALDI_WARN << "No stats accumulated for speaker " << spk;
           } else {
             writer.Write(spk, stats);
           }
         }
       } else {  // per-utterance normalization
         SequentialBaseFloatMatrixReader feat_reader(rspecifier);
 
         for (; !feat_reader.Done(); feat_reader.Next()) {
           std::string utt = feat_reader.Key();
           Matrix<double> stats;
           const Matrix<BaseFloat> &feats = feat_reader.Value();
           InitCmvnStats(feats.NumCols(), &stats);
 
           if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
             num_err++;
             continue;
           }
           writer.Write(feat_reader.Key(), stats);
           num_done++;
         }
       }
     } else { // accumulate global stats
       if (spk2utt_rspecifier != "")
         KALDI_ERR << "--spk2utt option not compatible with wxfilename as output "
                    << "(did you forget ark:?)";
       std::string wxfilename = wspecifier_or_wxfilename;
       bool is_init = false;
       Matrix<double> stats;
       SequentialBaseFloatMatrixReader feat_reader(rspecifier);
       for (; !feat_reader.Done(); feat_reader.Next()) {
         std::string utt = feat_reader.Key();
         const Matrix<BaseFloat> &feats = feat_reader.Value();
         if (!is_init) {
           InitCmvnStats(feats.NumCols(), &stats);
           is_init = true;
         }
         if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
           num_err++;
         } else {
           num_done++;
         }
       }
       Matrix<float> stats_float(stats);
       WriteKaldiObject(stats_float, wxfilename, binary);
       KALDI_LOG << "Wrote global CMVN stats to "
                 << PrintableWxfilename(wxfilename);
     }
     KALDI_LOG << "Done accumulating CMVN stats for " << num_done
               << " utterances; " << num_err << " had errors.";
     return (num_done != 0 ? 0 : 1);
   } catch(const std::exception &e) {
     std::cerr << e.what();
     return -1;
   }
 }
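
When --spk2utt is given, AccCmvnStatsWrapper is called on every utterance of a speaker, so the statistics are accumulated per speaker. Otherwise they are accumulated per utterance or, when the output is a wxfilename, globally.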

AccCmvnStatsWrapper
 bool AccCmvnStatsWrapper(const std::string &utt,
                          const MatrixBase<BaseFloat> &feats,
                          RandomAccessBaseFloatVectorReader *weights_reader,
                          Matrix<double> *cmvn_stats) {
   if (!weights_reader->IsOpen()) {
     AccCmvnStats(feats, NULL, cmvn_stats);
     return true;
   } else {
     if (!weights_reader->HasKey(utt)) {
       KALDI_WARN << "No weights available for utterance " << utt;
       return false;
     }
     const Vector<BaseFloat> &weights = weights_reader->Value(utt);
     if (weights.Dim() != feats.NumRows()) {
       KALDI_WARN << "Weights for utterance " << utt << " have wrong dimension "
                  << weights.Dim() << " vs. " << feats.NumRows();
       return false;
     }
     AccCmvnStats(feats, &weights, cmvn_stats);
     return true;
   }
 }

AccCmvnStatsWrapper mainly handles special cases (missing weights, or weights of the wrong dimension); the real work is done by AccCmvnStats.

AccCmvnStats


 void AccCmvnStats(const VectorBase<BaseFloat> &feats, BaseFloat weight, MatrixBase<double> *stats) {
   int32 dim = feats.Dim();
   KALDI_ASSERT(stats != NULL);
   KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() == dim + 1);
   // Remove these __restrict__ modifiers if they cause compilation problems.
   // It's just an optimization.
    double *__restrict__ mean_ptr = stats->RowData(0),
        *__restrict__ var_ptr = stats->RowData(1),
        *__restrict__ count_ptr = mean_ptr + dim;
    const BaseFloat * __restrict__ feats_ptr = feats.Data();
   *count_ptr += weight;
   // Careful-- if we change the format of the matrix, the "mean_ptr < count_ptr"
   // statement below might become wrong.
   for (; mean_ptr < count_ptr; mean_ptr++, var_ptr++, feats_ptr++) {
     *mean_ptr += *feats_ptr * weight;
     *var_ptr +=  *feats_ptr * *feats_ptr * weight;
   }
 }

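The function shown is the per-frame overload, applied to one row (frame) of the feature matrix at a time. After accumulation, the first row of stats holds, for each feature dimension, the weighted sum over all frames, and its last element holds the frame count (the sum of the weights). The second row holds the weighted sum of squares for each dimension.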
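The 2 x (dim + 1) layout is enough to recover the mean and variance later. The sketch below uses a hypothetical CmvnStats struct in place of Kaldi's matrix, and the Mean and Variance helpers are my own illustration of how such statistics are typically used (for example by a tool like apply-cmvn); they are not taken from the code above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the 2 x (dim + 1) stats matrix.
struct CmvnStats {
  std::vector<double> sum;    // row 0: per-dimension sums, last entry = count
  std::vector<double> sumsq;  // row 1: per-dimension sums of squares
  explicit CmvnStats(int dim) : sum(dim + 1, 0.0), sumsq(dim + 1, 0.0) {}
};

// Accumulate one frame with a given weight, as in the per-frame AccCmvnStats.
void Accumulate(const std::vector<float> &frame, double weight,
                CmvnStats *stats) {
  const std::size_t dim = frame.size();
  assert(stats != nullptr && stats->sum.size() == dim + 1);
  stats->sum[dim] += weight;  // frame count
  for (std::size_t d = 0; d < dim; ++d) {
    stats->sum[d] += weight * frame[d];
    stats->sumsq[d] += weight * frame[d] * frame[d];
  }
}

// Mean of dimension d: sum / count.
double Mean(const CmvnStats &s, std::size_t d) {
  const double count = s.sum.back();
  assert(count > 0.0 && d + 1 < s.sum.size());
  return s.sum[d] / count;
}

// Variance of dimension d: E[x^2] - (E[x])^2.
double Variance(const CmvnStats &s, std::size_t d) {
  const double m = Mean(s, d);
  return s.sumsq[d] / s.sum.back() - m * m;
}
```

Accumulating the two frames (1, 2) and (3, 4) with weight 1 gives sums (4, 6), a count of 2, and sums of squares (10, 20), so the means are (2, 3) and both variances are 1. Storing sums rather than means is what lets per-utterance statistics be added together into per-speaker or global ones.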

Finally, we can run a simple experiment with the compute-mfcc-feats and compute-cmvn-stats programs:
[Figure: screenshot of the experiment from the original post]
