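Analysis of Kaldi's MFCC Feature Extraction Code
This article is reprinted from the WeChat official account 433的3号同学.
The make_mfcc.sh script
First, look at the top-level script make_mfcc.sh, located at https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/make_mfcc.sh. Its usage is as follows: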
./steps/make_mfcc.sh
Usage: ./steps/make_mfcc.sh [options] <data-dir> [<log-dir> [<mfcc-dir>] ]
e.g.: ./steps/make_mfcc.sh data/train
Note: <log-dir> defaults to <data-dir>/log, and
<mfcc-dir> defaults to <data-dir>/data.
Options:
--mfcc-config <mfcc-config-file> # config passed to compute-mfcc-feats.
--nj <nj> # number of parallel jobs.
--cmd <run.pl|queue.pl <queue opts>> # how to run jobs.
--write-utt2num-frames <true|false> # If true, write utt2num_frames file.
--write-utt2dur <true|false> # If true, write utt2dur file.
$cmd JOB=1:$nj $logdir/make_mfcc_${name}.JOB.log \
compute-mfcc-feats $vtln_opts $write_utt2dur_opt --verbose=2 \
--config=$mfcc_config scp,p:$logdir/wav_${name}.JOB.scp ark:- \| \
copy-feats $write_num_frames_opt --compress=$compress ark:- \
ark,scp:$mfccdir/raw_mfcc_$name.JOB.ark,$mfccdir/raw_mfcc_$name.JOB.scp \
|| exit 1;
The excerpt above is the core of the script: the compute-mfcc-feats command. Its usage and options are as follows:
./compute-mfcc-feats
Create MFCC feature files.
Usage: compute-mfcc-feats [options...] <wav-rspecifier> <feats-wspecifier>
Options:
--allow-downsample : If true, allow the input waveform to have a higher frequency than the specified --sample-frequency (and we'll downsample). (bool, default = false)
--allow-upsample : If true, allow the input waveform to have a lower frequency than the specified --sample-frequency (and we'll upsample). (bool, default = false)
--blackman-coeff : Constant coefficient for generalized Blackman window. (float, default = 0.42)
--cepstral-lifter : Constant that controls scaling of MFCCs (float, default = 22)
--channel : Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (int, default = -1)
--debug-mel : Print out debugging information for mel bin computation (bool, default = false)
--dither : Dithering constant (0.0 means no dither). If you turn this off, you should set the --energy-floor option, e.g. to 1.0 or 0.1 (float, default = 1)
--energy-floor : Floor on energy (absolute, not relative) in MFCC computation. Only makes a difference if --use-energy=true; only necessary if --dither=0.0. Suggested values: 0.1 or 1.0 (float, default = 0)
--frame-length : Frame length in milliseconds (float, default = 25)
--frame-shift : Frame shift in milliseconds (float, default = 10)
--high-freq : High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (float, default = 0)
--htk-compat : If true, put energy or C0 last and use a factor of sqrt(2) on C0. Warning: not sufficient to get HTK compatible features (need to change other parameters). (bool, default = false)
--low-freq : Low cutoff frequency for mel bins (float, default = 20)
--max-feature-vectors : Memory optimization. If larger than 0, periodically remove feature vectors so that only this number of the latest feature vectors is retained. (int, default = -1)
--min-duration : Minimum duration of segments to process (in seconds). (float, default = 0)
--num-ceps : Number of cepstra in MFCC computation (including C0) (int, default = 13)
--num-mel-bins : Number of triangular mel-frequency bins (int, default = 23)
--output-format : Format of the output files [kaldi, htk] (string, default = "kaldi")
--preemphasis-coefficient : Coefficient for use in signal preemphasis (float, default = 0.97)
--raw-energy : If true, compute energy before preemphasis and windowing (bool, default = true)
--remove-dc-offset : Subtract mean from waveform on each frame (bool, default = true)
--round-to-power-of-two : If true, round window size to power of two by zero-padding input to FFT. (bool, default = true)
--sample-frequency : Waveform data sample frequency (must match the waveform file, if specified there) (float, default = 16000)
--snip-edges : If true, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame-length. If false, the number of frames depends only on the frame-shift, and we reflect the data at the ends. (bool, default = true)
--subtract-mean : Subtract mean of each feature file [CMS]; not recommended to do it this way. (bool, default = false)
--use-energy : Use energy (not C0) in MFCC computation (bool, default = true)
--utt2spk : Utterance to speaker-id map rspecifier (if doing VTLN and you have warps per speaker) (string, default = "")
--vtln-high : High inflection point in piecewise linear VTLN warping function (if negative, offset from high-mel-freq (float, default = -500)
--vtln-low : Low inflection point in piecewise linear VTLN warping function (float, default = 100)
--vtln-map : Map from utterance or speaker-id to vtln warp factor (rspecifier) (string, default = "")
--vtln-warp : Vtln warp factor (only applicable if vtln-map not specified) (float, default = 1)
--window-type : Type of window ("hamming"|"hanning"|"povey"|"rectangular"|"blackmann") (string, default = "povey")
--write-utt2dur : Wspecifier to write duration of each utterance in seconds, e.g. 'ark,t:utt2dur'. (string, default = "")
Standard options:
--config : Configuration file to read (this option may be repeated) (string, default = "")
--help : Print out usage message (bool, default = false)
--print-args : Print the command line arguments (to stderr) (bool, default = true)
--verbose : Verbose level (higher->more logging) (int, default = 0)
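There are many options, but only two arguments are required: the input wav-rspecifier and the output feats-wspecifier. Readers unfamiliar with rspecifiers and wspecifiers should first read Kaldi文档解读 (a walkthrough of the Kaldi documentation) and the Kaldi I/O mechanisms documentation. First, we need to prepare the input files.
compute-mfcc-feats can only read WAV data, so audio in other formats must be converted to WAV first. The conversion can be done "offline", i.e. ahead of time with a conversion tool.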
compute-mfcc-feats.cc
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
const char *usage =
"Create MFCC feature files.\n"
"Usage: compute-mfcc-feats [options...] <wav-rspecifier> "
"<feats-wspecifier>\n";
// Construct all the global objects.
ParseOptions po(usage);
MfccOptions mfcc_opts;
// Define defaults for global options.
bool subtract_mean = false;
BaseFloat vtln_warp = 1.0;
std::string vtln_map_rspecifier;
std::string utt2spk_rspecifier;
int32 channel = -1;
BaseFloat min_duration = 0.0;
std::string output_format = "kaldi";
std::string utt2dur_wspecifier;
// Register the MFCC option struct.
mfcc_opts.Register(&po);
// Register the options.
po.Register("output-format", &output_format, "Format of the output "
"files [kaldi, htk]");
po.Register("subtract-mean", &subtract_mean, "Subtract mean of each "
"feature file [CMS]; not recommended to do it this way. ");
po.Register("vtln-warp", &vtln_warp, "Vtln warp factor (only applicable "
"if vtln-map not specified)");
po.Register("vtln-map", &vtln_map_rspecifier, "Map from utterance or "
"speaker-id to vtln warp factor (rspecifier)");
po.Register("utt2spk", &utt2spk_rspecifier, "Utterance to speaker-id map "
"rspecifier (if doing VTLN and you have warps per speaker)");
po.Register("channel", &channel, "Channel to extract (-1 -> expect mono, "
"0 -> left, 1 -> right)");
po.Register("min-duration", &min_duration, "Minimum duration of segments "
"to process (in seconds).");
po.Register("write-utt2dur", &utt2dur_wspecifier, "Wspecifier to write "
"duration of each utterance in seconds, e.g. 'ark,t:utt2dur'.");
po.Read(argc, argv);
if (po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
std::string wav_rspecifier = po.GetArg(1);
std::string output_wspecifier = po.GetArg(2);
Mfcc mfcc(mfcc_opts);
if (utt2spk_rspecifier != "" && vtln_map_rspecifier == "")
KALDI_ERR << ("The --utt2spk option is only needed if "
"the --vtln-map option is used.");
RandomAccessBaseFloatReaderMapped vtln_map_reader(vtln_map_rspecifier,
utt2spk_rspecifier);
SequentialTableReader<WaveHolder> reader(wav_rspecifier);
BaseFloatMatrixWriter kaldi_writer; // typedef to TableWriter<KaldiObjectHolder<Matrix<BaseFloat> > >.
TableWriter<HtkMatrixHolder> htk_writer;
if (output_format == "kaldi") {
if (!kaldi_writer.Open(output_wspecifier))
KALDI_ERR << "Could not initialize output with wspecifier "
<< output_wspecifier;
} else if (output_format == "htk") {
if (!htk_writer.Open(output_wspecifier))
KALDI_ERR << "Could not initialize output with wspecifier "
<< output_wspecifier;
} else {
KALDI_ERR << "Invalid output_format string " << output_format;
}
DoubleWriter utt2dur_writer(utt2dur_wspecifier);
int32 num_utts = 0, num_success = 0;
for (; !reader.Done(); reader.Next()) {
num_utts++;
std::string utt = reader.Key();
const WaveData &wave_data = reader.Value();
if (wave_data.Duration() < min_duration) {
KALDI_WARN << "File: " << utt << " is too short ("
<< wave_data.Duration() << " sec): producing no output.";
continue;
}
int32 num_chan = wave_data.Data().NumRows(), this_chan = channel;
{ // This block works out the channel (0=left, 1=right...)
KALDI_ASSERT(num_chan > 0); // should have been caught in
// reading code if no channels.
if (channel == -1) {
this_chan = 0;
if (num_chan != 1)
KALDI_WARN << "Channel not specified but you have data with "
<< num_chan << " channels; defaulting to zero";
} else {
if (this_chan >= num_chan) {
KALDI_WARN << "File with id " << utt << " has "
<< num_chan << " channels but you specified channel "
<< channel << ", producing no output.";
continue;
}
}
}
BaseFloat vtln_warp_local; // Work out VTLN warp factor.
if (vtln_map_rspecifier != "") {
if (!vtln_map_reader.HasKey(utt)) {
KALDI_WARN << "No vtln-map entry for utterance-id (or speaker-id) "
<< utt;
continue;
}
vtln_warp_local = vtln_map_reader.Value(utt);
} else {
vtln_warp_local = vtln_warp;
}
SubVector<BaseFloat> waveform(wave_data.Data(), this_chan);
Matrix<BaseFloat> features;
try {
mfcc.ComputeFeatures(waveform, wave_data.SampFreq(),
vtln_warp_local, &features);
} catch (...) {
KALDI_WARN << "Failed to compute features for utterance " << utt;
continue;
}
if (subtract_mean) {
Vector<BaseFloat> mean(features.NumCols());
mean.AddRowSumMat(1.0, features);
mean.Scale(1.0 / features.NumRows());
for (int32 i = 0; i < features.NumRows(); i++)
features.Row(i).AddVec(-1.0, mean);
}
if (output_format == "kaldi") {
kaldi_writer.Write(utt, features);
} else {
std::pair<Matrix<BaseFloat>, HtkHeader> p;
p.first.Resize(features.NumRows(), features.NumCols());
p.first.CopyFromMat(features);
HtkHeader header = {
features.NumRows(),
100000, // 10ms shift
static_cast<int16>(sizeof(float)*(features.NumCols())),
static_cast<uint16>( 006 | // MFCC
(mfcc_opts.use_energy ? 0100 : 020000)) // energy; otherwise c0
};
p.second = header;
htk_writer.Write(utt, p);
}
if (utt2dur_writer.IsOpen()) {
utt2dur_writer.Write(utt, wave_data.Duration());
}
if (num_utts % 10 == 0)
KALDI_LOG << "Processed " << num_utts << " utterances";
KALDI_VLOG(2) << "Processed features for key " << utt;
num_success++;
}
KALDI_LOG << " Done " << num_success << " out of " << num_utts
<< " utterances.";
return (num_success != 0 ? 0 : 1);
} catch(const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
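The HTK branch at the end of main() fills an HtkHeader by hand. The small standalone sketch below works through that arithmetic. The helper names are hypothetical. The octal codes are the ones in the code above: HTK's MFCC base kind (006) plus either the _E energy qualifier (0100) or the _0 C0 qualifier (020000). Note that main() hard-codes a 10 ms sample period (100000 in units of 100 ns), whatever --frame-shift is set to.

```cpp
#include <cassert>
#include <cstdint>

// HTK parameter-kind codes, written in octal exactly as in main() above:
// 006 = MFCC base kind, 0100 = _E qualifier (energy), 020000 = _0 qualifier (C0).
constexpr std::uint16_t kHtkMfcc = 006;
constexpr std::uint16_t kHtkEnergyQualifier = 0100;
constexpr std::uint16_t kHtkC0Qualifier = 020000;

// Hypothetical helper: the mSampleKind value written for MFCC features.
std::uint16_t HtkParmKind(bool use_energy) {
  return static_cast<std::uint16_t>(
      kHtkMfcc | (use_energy ? kHtkEnergyQualifier : kHtkC0Qualifier));
}

// Hypothetical helper: bytes per frame, i.e. num_cols 4-byte floats.
std::int16_t HtkSampleSize(int num_cols) {
  return static_cast<std::int16_t>(sizeof(float) * num_cols);
}

// Hypothetical helper: HTK sample periods are in units of 100 ns,
// so a 10 ms shift gives the 100000 that main() hard-codes.
std::int32_t HtkSamplePeriod(double frame_shift_ms) {
  return static_cast<std::int32_t>(frame_shift_ms * 10000.0 + 0.5);
}
```

For the default 13-dimensional MFCCs with energy, the header therefore records 52 bytes per frame and parameter kind 70 (MFCC_E). With --use-energy=false the kind is 8198 (MFCC_0).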
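We can skip the argument and option parsing at the start. The line that matters is "Mfcc mfcc(mfcc_opts);", which creates the object that does the real work. Mfcc is a typedef of the OfflineFeatureTpl class template instantiated with MfccComputer:
typedef OfflineFeatureTpl<MfccComputer> Mfcc;
The official documentation describes OfflineFeatureTpl as follows:
This template class is intended for offline feature extraction, i.e. when you have access to the entire signal from the start.
It exists mainly as a replacement for the old (pre-2016) classes such as Mfcc and Plp, for use in the offline case. In April 2016 the online feature-computation code was reorganized for greater modularity and to support the snip-edges=false option correctly.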
ComputeFeatures
Let's look at the function that computes the features:
template <class F>
void OfflineFeatureTpl<F>::ComputeFeatures(
const VectorBase<BaseFloat> &wave,
BaseFloat sample_freq,
BaseFloat vtln_warp,
Matrix<BaseFloat> *output) {
KALDI_ASSERT(output != NULL);
BaseFloat new_sample_freq = computer_.GetFrameOptions().samp_freq;
if (sample_freq == new_sample_freq) {
Compute(wave, vtln_warp, output);
} else {
if (new_sample_freq < sample_freq &&
! computer_.GetFrameOptions().allow_downsample)
KALDI_ERR << "Waveform and config sample Frequency mismatch: "
<< sample_freq << " .vs " << new_sample_freq
<< " (use --allow-downsample=true to allow "
<< " downsampling the waveform).";
else if (new_sample_freq > sample_freq &&
! computer_.GetFrameOptions().allow_upsample)
KALDI_ERR << "Waveform and config sample Frequency mismatch: "
<< sample_freq << " .vs " << new_sample_freq
<< " (use --allow-upsample=true option to allow "
<< " upsampling the waveform).";
// Resample the waveform.
Vector<BaseFloat> resampled_wave(wave);
ResampleWaveform(sample_freq, wave,
new_sample_freq, &resampled_wave);
Compute(resampled_wave, vtln_warp, output);
}
}
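All it does is check whether the sample rate read from the WAV header matches the one configured for compute-mfcc-feats (--sample-frequency). If they match, it calls Compute directly. Otherwise, if allowed (--allow-downsample or --allow-upsample), it downsamples or upsamples the waveform to the configured rate and then calls Compute. If resampling is not allowed, it exits with an error.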
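The branching in ComputeFeatures can be restated as a small decision function. This is only a sketch: the enum and function names are hypothetical, the function throws std::runtime_error where Kaldi calls KALDI_ERR, and the actual resampling (ResampleWaveform) is not reproduced.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical enum naming the three outcomes of the branches above.
enum class RateAction { kComputeDirectly, kDownsample, kUpsample };

// file_rate: rate read from the WAV header; config_rate: --sample-frequency.
RateAction DecideRateAction(double file_rate, double config_rate,
                            bool allow_downsample, bool allow_upsample) {
  if (file_rate == config_rate) return RateAction::kComputeDirectly;
  if (config_rate < file_rate) {
    // Configured rate is lower than the file's: downsampling is needed.
    if (!allow_downsample)
      throw std::runtime_error(
          "sample frequency mismatch; use --allow-downsample=true");
    return RateAction::kDownsample;
  }
  // Configured rate is higher than the file's: upsampling is needed.
  if (!allow_upsample)
    throw std::runtime_error(
        "sample frequency mismatch; use --allow-upsample=true");
  return RateAction::kUpsample;
}
```

For example, 44.1 kHz recordings used with a 16 kHz config are only accepted with --allow-downsample=true. In that case they are resampled to 16 kHz before Compute runs.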
Compute
template <class F>
void OfflineFeatureTpl<F>::Compute(
const VectorBase<BaseFloat> &wave,
BaseFloat vtln_warp,
Matrix<BaseFloat> *output) {
KALDI_ASSERT(output != NULL);
int32 rows_out = NumFrames(wave.Dim(), computer_.GetFrameOptions()),
cols_out = computer_.Dim();
if (rows_out == 0) {
output->Resize(0, 0);
return;
}
output->Resize(rows_out, cols_out);
Vector<BaseFloat> window; // windowed waveform.
bool use_raw_log_energy = computer_.NeedRawLogEnergy();
for (int32 r = 0; r < rows_out; r++) { // r is frame index.
BaseFloat raw_log_energy = 0.0;
ExtractWindow(0, wave, r, computer_.GetFrameOptions(),
feature_window_function_, &window,
(use_raw_log_energy ? &raw_log_energy : NULL));
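SubVector<BaseFloat> output_row(*output, r);
computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row);
}
}
Compute first uses NumFrames to work out how many frames the waveform contains. It then loops over the frames: ExtractWindow extracts each frame, and computer_.Compute turns it into one row of features.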
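To make this division of labour concrete, here is a minimal sketch of the same pattern. It assumes simplified snip-edges framing with no padding, windowing or FFT. OfflineFeatureSketch, EnergyComputer and EnergyFeatures are hypothetical names. The point is only that the template owns the frame loop while the computer class maps one frame to one output row, which is how OfflineFeatureTpl<MfccComputer> is organised.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for OfflineFeatureTpl<F>: owns framing and the loop.
template <class F>
class OfflineFeatureSketch {
 public:
  OfflineFeatureSketch(int frame_length, int frame_shift)
      : frame_length_(frame_length), frame_shift_(frame_shift) {}

  std::vector<std::vector<float>> Compute(const std::vector<float> &wave) {
    std::vector<std::vector<float>> output;
    const int num_samples = static_cast<int>(wave.size());
    if (num_samples < frame_length_) return output;  // snip-edges style
    const int num_frames = 1 + (num_samples - frame_length_) / frame_shift_;
    for (int r = 0; r < num_frames; r++) {  // r is the frame index
      // Analogue of ExtractWindow: copy frame r out of the waveform.
      std::vector<float> window(wave.begin() + r * frame_shift_,
                                wave.begin() + r * frame_shift_ + frame_length_);
      // Analogue of computer_.Compute: one frame -> one feature row.
      std::vector<float> row = computer_.Compute(window);
      assert(static_cast<int>(row.size()) == computer_.Dim());
      output.push_back(row);
    }
    return output;
  }

 private:
  int frame_length_, frame_shift_;
  F computer_;
};

// Toy per-frame computer standing in for MfccComputer: outputs frame energy.
struct EnergyComputer {
  int Dim() const { return 1; }
  std::vector<float> Compute(const std::vector<float> &frame) const {
    float energy = 0.0f;
    for (float s : frame) energy += s * s;
    return {energy};
  }
};

// Analogous to "typedef OfflineFeatureTpl<MfccComputer> Mfcc;".
typedef OfflineFeatureSketch<EnergyComputer> EnergyFeatures;
```

Swapping EnergyComputer for another class with the same Dim()/Compute() interface changes the feature type without touching the framing code. This mirrors how Kaldi reuses OfflineFeatureTpl for MFCC, PLP and filterbank features.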
NumFrames
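The default is snip_edges=true, which also matches HTK. Every frame must fit entirely inside the waveform, so no padding is needed, and a frame that would run past the end of the signal is simply dropped.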
int32 NumFrames(int64 num_samples,
const FrameExtractionOptions &opts,
bool flush) {
int64 frame_shift = opts.WindowShift();
int64 frame_length = opts.WindowSize();
if (opts.snip_edges) {
// with --snip-edges=true (the default), we use a HTK-like approach to
// determining the number of frames-- all frames have to fit completely into
// the waveform, and the first frame begins at sample zero.
if (num_samples < frame_length)
return 0;
else
return (1 + ((num_samples - frame_length) / frame_shift));
// You can understand the expression above as follows: 'num_samples -
// frame_length' is how much room we have to shift the frame within the
// waveform; 'frame_shift' is how much we shift it each time; and the ratio
// is how many times we can shift it (integer arithmetic rounds down).
} else {
// if --snip-edges=false, the number of frames is determined by rounding the
// (file-length / frame-shift) to the nearest integer. The point of this
// formula is to make the number of frames an obvious and predictable
// function of the frame shift and signal length, which makes many
// segmentation-related questions simpler.
//
// Because integer division in C++ rounds toward zero, we add (half the
// frame-shift minus epsilon) before dividing, to have the effect of
// rounding towards the closest integer.
int32 num_frames = (num_samples + (frame_shift / 2)) / frame_shift;
if (flush)
return num_frames;
// note: 'end' always means the last plus one, i.e. one past the last.
int64 end_sample_of_last_frame = FirstSampleOfFrame(num_frames - 1, opts)
+ frame_length;
// the following code is optimized more for clarity than efficiency.
// If flush == false, we can't output frames that extend past the end
// of the signal.
while (num_frames > 0 && end_sample_of_last_frame > num_samples) {
num_frames--;
end_sample_of_last_frame -= frame_shift;
}
return num_frames;
}
}
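A standalone restatement of this logic makes the two modes easy to compare numerically. The sketch assumes frame length and shift are already in samples: 400 and 160 for the default 25 ms and 10 ms at 16 kHz. FrameOpts, FirstSample and CountFrames are hypothetical names. FirstSample follows the formula we believe Kaldi's FirstSampleOfFrame uses, where frames are centred on frame_shift * (f + 0.5) when snip_edges is false.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical options struct; lengths are in samples.
struct FrameOpts {
  std::int64_t frame_length;  // like WindowSize()
  std::int64_t frame_shift;   // like WindowShift()
  bool snip_edges;
};

// First sample of frame f; with snip_edges=false this can be negative.
std::int64_t FirstSample(std::int64_t f, const FrameOpts &o) {
  if (o.snip_edges) return f * o.frame_shift;
  std::int64_t midpoint = o.frame_shift * f + o.frame_shift / 2;
  return midpoint - o.frame_length / 2;
}

std::int64_t CountFrames(std::int64_t num_samples, const FrameOpts &o,
                         bool flush = true) {
  if (o.snip_edges) {
    // Every frame must fit completely inside the signal.
    if (num_samples < o.frame_length) return 0;
    return 1 + (num_samples - o.frame_length) / o.frame_shift;
  }
  // Round num_samples / frame_shift to the nearest integer.
  std::int64_t num_frames = (num_samples + o.frame_shift / 2) / o.frame_shift;
  if (flush) return num_frames;
  // Without flush, drop trailing frames that extend past the signal end.
  std::int64_t end = FirstSample(num_frames - 1, o) + o.frame_length;
  while (num_frames > 0 && end > num_samples) {
    num_frames--;
    end -= o.frame_shift;
  }
  return num_frames;
}
```

Take one second of 16 kHz audio. It gives 98 frames with --snip-edges=true, i.e. 1 + (16000 - 400) / 160, rounded down. It gives 100 frames with --snip-edges=false, or 99 if flush is false, because in that mode the count depends only on the frame shift.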
ExtractWindow
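Extracting each frame involves a few small tricks:
- accounting for the sample offset;
- with --snip-edges=false, reflecting the signal for frames that run past either end;
- zero-padding the frame up to a power of two for the FFT when --round-to-power-of-two=true;
- inside ProcessWindow, applying dithering (adding very small random noise so that we never take the log of zero), applying pre-emphasis, and computing the frame's energy.
Finally, the samples are multiplied by the window function FeatureWindowFunction. By default this is the "povey" window, which Daniel Povey designed himself.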
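ProcessWindow itself is not shown in this article. Below is a minimal sketch of its per-frame steps, assuming the order dither, DC removal, raw log energy, pre-emphasis, then window multiplication. This is our reading of Kaldi's feature-window code, not a verbatim copy. ProcessOpts and ProcessFrame are hypothetical names, and the window is passed in precomputed. The Gaussian dither uses std::mt19937 rather than Kaldi's own random-number helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <vector>

// Hypothetical options mirroring --dither, --remove-dc-offset and
// --preemphasis-coefficient.
struct ProcessOpts {
  float dither = 1.0f;
  bool remove_dc_offset = true;
  float preemph_coeff = 0.97f;
};

// Processes one frame in place and returns its raw log energy (computed
// before pre-emphasis and windowing). Kaldi only computes this value when it
// is requested; the sketch always does.
float ProcessFrame(const ProcessOpts &opts, const std::vector<float> &window_fn,
                   std::vector<float> *frame, std::mt19937 *rng) {
  assert(!frame->empty() && frame->size() == window_fn.size());
  const std::size_t n = frame->size();
  if (opts.dither != 0.0f) {
    // Tiny Gaussian noise so that no filterbank energy is exactly zero.
    std::normal_distribution<float> gauss(0.0f, 1.0f);
    for (float &s : *frame) s += opts.dither * gauss(*rng);
  }
  if (opts.remove_dc_offset) {
    float mean = 0.0f;
    for (float s : *frame) mean += s;
    mean /= static_cast<float>(n);
    for (float &s : *frame) s -= mean;
  }
  float energy = 0.0f;
  for (float s : *frame) energy += s * s;
  const float raw_log_energy =
      std::log(std::max(energy, std::numeric_limits<float>::epsilon()));
  if (opts.preemph_coeff != 0.0f) {
    // x[i] -= a * x[i-1], walking backwards; x[0] uses itself as predecessor.
    for (std::size_t i = n - 1; i > 0; i--)
      (*frame)[i] -= opts.preemph_coeff * (*frame)[i - 1];
    (*frame)[0] -= opts.preemph_coeff * (*frame)[0];
  }
  for (std::size_t i = 0; i < n; i++) (*frame)[i] *= window_fn[i];
  return raw_log_energy;
}
```

Because the energy is taken before pre-emphasis and windowing, this corresponds to the --raw-energy=true default. With --raw-energy=false, MfccComputer::Compute instead computes the energy from the already-windowed frame.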
void ExtractWindow(int64 sample_offset,
const VectorBase<BaseFloat> &wave,
int32 f, // with 0 <= f < NumFrames(feats, opts)
const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function,
Vector<BaseFloat> *window,
BaseFloat *log_energy_pre_window) {
KALDI_ASSERT(sample_offset >= 0 && wave.Dim() != 0);
int32 frame_length = opts.WindowSize(),
frame_length_padded = opts.PaddedWindowSize();
int64 num_samples = sample_offset + wave.Dim(),
start_sample = FirstSampleOfFrame(f, opts),
end_sample = start_sample + frame_length;
if (opts.snip_edges) {
KALDI_ASSERT(start_sample >= sample_offset &&
end_sample <= num_samples);
} else {
KALDI_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
}
if (window->Dim() != frame_length_padded)
window->Resize(frame_length_padded, kUndefined);
// wave_start and wave_end are start and end indexes into 'wave', for the
// piece of wave that we're trying to extract.
int32 wave_start = int32(start_sample - sample_offset),
wave_end = wave_start + frame_length;
if (wave_start >= 0 && wave_end <= wave.Dim()) {
// the normal case-- no edge effects to consider.
window->Range(0, frame_length).CopyFromVec(
wave.Range(wave_start, frame_length));
} else {
// Deal with any end effects by reflection, if needed. This code will only
// be reached for about two frames per utterance, so we don't concern
// ourselves excessively with efficiency.
int32 wave_dim = wave.Dim();
for (int32 s = 0; s < frame_length; s++) {
int32 s_in_wave = s + wave_start;
while (s_in_wave < 0 || s_in_wave >= wave_dim) {
// reflect around the beginning or end of the wave.
// e.g. -1 -> 0, -2 -> 1.
// dim -> dim - 1, dim + 1 -> dim - 2.
// the code supports repeated reflections, although this
// would only be needed in pathological cases.
if (s_in_wave < 0) s_in_wave = - s_in_wave - 1;
else s_in_wave = 2 * wave_dim - 1 - s_in_wave;
}
(*window)(s) = wave(s_in_wave);
}
}
if (frame_length_padded > frame_length)
window->Range(frame_length, frame_length_padded - frame_length).SetZero();
SubVector<BaseFloat> frame(*window, 0, frame_length);
ProcessWindow(opts, window_function, &frame, log_energy_pre_window);
}
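Two details of ExtractWindow are easy to check in isolation. The first is the reflection used for frames that run past either end of the signal, which is only reachable with --snip-edges=false. The second is the zero-padded length used when --round-to-power-of-two=true. Both helpers are hypothetical names. RoundUpToPowerOfTwo is our assumption of what PaddedWindowSize() returns in that mode.

```cpp
#include <cassert>
#include <cstdint>

// Maps an out-of-range index back into [0, dim) by reflection, as in the loop
// above: -1 -> 0, -2 -> 1, dim -> dim - 1, dim + 1 -> dim - 2.
// Repeated reflection handles pathological cases.
int ReflectIndex(int s_in_wave, int dim) {
  assert(dim > 0);
  while (s_in_wave < 0 || s_in_wave >= dim) {
    if (s_in_wave < 0)
      s_in_wave = -s_in_wave - 1;
    else
      s_in_wave = 2 * dim - 1 - s_in_wave;
  }
  return s_in_wave;
}

// Smallest power of two >= n; samples beyond the real frame length are zeroed.
std::int32_t RoundUpToPowerOfTwo(std::int32_t n) {
  assert(n > 0);
  std::int32_t p = 1;
  while (p < n) p <<= 1;
  return p;
}
```

With the default 25 ms frames at 16 kHz, the 400-sample frame is padded with 112 zeros to 512 samples before the FFT.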
The window function applied in ProcessWindow is built by the FeatureWindowFunction constructor:
FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) {
int32 frame_length = opts.WindowSize();
KALDI_ASSERT(frame_length > 0);
window.Resize(frame_length);
double a = M_2PI / (frame_length-1);
for (int32 i = 0; i < frame_length; i++) {
double i_fl = static_cast<double>(i);
if (opts.window_type == "hanning") {
window(i) = 0.5 - 0.5*cos(a * i_fl);
} else if (opts.window_type == "hamming") {
window(i) = 0.54 - 0.46*cos(a * i_fl);
} else if (opts.window_type == "povey") { // like hamming but goes to zero at edges.
window(i) = pow(0.5 - 0.5*cos(a * i_fl), 0.85);
} else if (opts.window_type == "rectangular") {
window(i) = 1.0;
} else if (opts.window_type == "blackman") {
window(i) = opts.blackman_coeff - 0.5*cos(a * i_fl) +
(0.5 - opts.blackman_coeff) * cos(2 * a * i_fl);
} else {
KALDI_ERR << "Invalid window type " << opts.window_type;
}
}
}
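The formulas above can be checked outside Kaldi with a short sketch. MakeWindow is a hypothetical function that reproduces the same expressions with a = 2π/(N-1). It shows that the povey window, like Hanning, falls to zero at both ends, while Hamming stays at 0.08 there.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>
#include <string>
#include <vector>

// Same per-sample formulas as FeatureWindowFunction; N = frame_length.
std::vector<double> MakeWindow(const std::string &type, int frame_length,
                               double blackman_coeff = 0.42) {
  assert(frame_length > 1);  // a = 2*pi/(N-1) needs N > 1
  const double kPi = std::acos(-1.0);
  const double a = 2.0 * kPi / (frame_length - 1);
  std::vector<double> w(frame_length);
  for (int i = 0; i < frame_length; i++) {
    const double x = a * i;
    if (type == "hanning") {
      w[i] = 0.5 - 0.5 * std::cos(x);
    } else if (type == "hamming") {
      w[i] = 0.54 - 0.46 * std::cos(x);
    } else if (type == "povey") {
      // Hanning raised to the power 0.85: zero at the edges, fuller middle.
      w[i] = std::pow(0.5 - 0.5 * std::cos(x), 0.85);
    } else if (type == "rectangular") {
      w[i] = 1.0;
    } else if (type == "blackman") {
      w[i] = blackman_coeff - 0.5 * std::cos(x) +
             (0.5 - blackman_coeff) * std::cos(2 * x);
    } else {
      throw std::invalid_argument("Invalid window type " + type);
    }
  }
  return w;
}
```

The 0.85 exponent raises the small Hanning values near the edges. The povey window therefore keeps Hanning's zero endpoints but is wider in the middle.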
MfccComputer
Finally, we reach the code that does the real work. OfflineFeatureTpl::Compute above calls MfccComputer::Compute on each frame:
void MfccComputer::Compute(BaseFloat signal_raw_log_energy,
BaseFloat vtln_warp,
VectorBase<BaseFloat> *signal_frame,
VectorBase<BaseFloat> *feature) {
KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() &&
feature->Dim() == this->Dim());
// Get the mel filterbank. For reuse, banks are cached in a map keyed by the VTLN warp factor (alpha).
const MelBanks &mel_banks = *(GetMelBanks(vtln_warp));
if (opts_.use_energy && !opts_.raw_energy)
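// Compute the log energy via a vector-vector dot product (sum of squares);
// this path is only taken when --raw-energy=false.
signal_raw_log_energy = Log(std::max<BaseFloat>(VecVec(*signal_frame, *signal_frame),
std::numeric_limits<float>::epsilon()));
// FFT; split-radix is used by default (when the padded window size is a power of two).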
if (srfft_ != NULL) // Compute FFT using the split-radix algorithm.
srfft_->Compute(signal_frame->Data(), true);
else // An alternative algorithm that works for non-powers-of-two.
RealFft(signal_frame, true);
// The squared magnitudes of the complex FFT bins give the power spectrum.
// Convert the FFT into a power spectrum.
ComputePowerSpectrum(signal_frame);
SubVector<BaseFloat> power_spectrum(*signal_frame, 0,
signal_frame->Dim() / 2 + 1);
// Apply the mel filterbank to get the energy in each filter.
mel_banks.Compute(power_spectrum, &mel_energies_);
// Dithering should keep the energies non-zero, but dithering is optional, so floor them anyway to be safe.
// avoid log of zero (which should be prevented anyway by dithering).
mel_energies_.ApplyFloor(std::numeric_limits<float>::epsilon());
mel_energies_.ApplyLog(); // take the log.
feature->SetZero(); // in case there were NaNs.
// Apply the DCT to get the cepstrum.
// feature = dct_matrix_ * mel_energies [which now have log]
feature->AddMatVec(1.0, dct_matrix_, kNoTrans, mel_energies_, 0.0);
if (opts_.cepstral_lifter != 0.0)
feature->MulElements(lifter_coeffs_);
// If use_energy is set, replace the first cepstral coefficient (C0) with this frame's log energy.
if (opts_.use_energy) {
if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_)
signal_raw_log_energy = log_energy_floor_;
(*feature)(0) = signal_raw_log_energy;
}
if (opts_.htk_compat) {
BaseFloat energy = (*feature)(0);
for (int32 i = 0; i < opts_.num_ceps - 1; i++)
(*feature)(i) = (*feature)(i+1);
if (!opts_.use_energy)
energy *= M_SQRT2; // scale on C0 (actually removing a scale
// we previously added that's part of one common definition of
// the cosine transform.)
(*feature)(opts_.num_ceps - 1) = energy;
}
}
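The steps after the filterbank (DCT, liftering, replacing C0 by the log energy) can be sketched on their own. The helper names are hypothetical. The formulas are our understanding of Kaldi's MelScale (1127·ln(1 + f/700)), ComputeDctMatrix and ComputeLifterCoeffs; the FFT and the triangular mel filters are left out. The first DCT row is scaled by sqrt(1/N) rather than sqrt(2/N). This is why the htk_compat branch above multiplies C0 by sqrt(2) to match HTK's convention.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

const double kPi = std::acos(-1.0);

// Mel scale as assumed for Kaldi's MelBanks::MelScale.
double MelScale(double freq_hz) { return 1127.0 * std::log(1.0 + freq_hz / 700.0); }

// First num_ceps rows of an orthonormal DCT-II over num_bins log-mel energies.
std::vector<std::vector<double>> DctMatrix(int num_ceps, int num_bins) {
  std::vector<std::vector<double>> m(num_ceps, std::vector<double>(num_bins));
  for (int k = 0; k < num_bins; k++) m[0][k] = std::sqrt(1.0 / num_bins);
  for (int n = 1; n < num_ceps; n++)
    for (int k = 0; k < num_bins; k++)
      m[n][k] = std::sqrt(2.0 / num_bins) * std::cos(kPi / num_bins * (k + 0.5) * n);
  return m;
}

// Cepstral liftering weights: 1 + (Q/2) * sin(pi * i / Q), Q = --cepstral-lifter.
std::vector<double> LifterCoeffs(int num_ceps, double q) {
  std::vector<double> c(num_ceps);
  for (int i = 0; i < num_ceps; i++) c[i] = 1.0 + 0.5 * q * std::sin(kPi * i / q);
  return c;
}

// log_mel: filterbank energies, already floored and logged.
std::vector<double> CepstraFromLogMel(const std::vector<double> &log_mel,
                                      int num_ceps, double cepstral_lifter,
                                      bool use_energy, double log_energy) {
  const int num_bins = static_cast<int>(log_mel.size());
  assert(num_ceps <= num_bins);
  const auto dct = DctMatrix(num_ceps, num_bins);
  std::vector<double> feat(num_ceps, 0.0);  // feature = dct * log_mel
  for (int n = 0; n < num_ceps; n++)
    for (int k = 0; k < num_bins; k++) feat[n] += dct[n][k] * log_mel[k];
  if (cepstral_lifter != 0.0) {
    const auto lifter = LifterCoeffs(num_ceps, cepstral_lifter);
    for (int n = 0; n < num_ceps; n++) feat[n] *= lifter[n];
  }
  if (use_energy) feat[0] = log_energy;  // C0 replaced by the log energy
  return feat;
}
```

As a sanity check, a flat log-mel vector of 23 ones yields C0 = sqrt(23) and (numerically) zero for all higher coefficients. This holds because each higher DCT row sums to zero.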
With this, we can run a simple experiment with the compute-mfcc-feats program (compute-mfcc-feats.cc).
compute_cmvn_stats.sh
Next, look at the top-level script compute_cmvn_stats.sh, located at https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/compute_cmvn_stats.sh. Its usage is as follows:
steps/compute_cmvn_stats.sh
Usage: steps/compute_cmvn_stats.sh [options] <data-dir> [<log-dir> [<cmvn-dir>] ]
e.g.: steps/compute_cmvn_stats.sh data/train exp/make_mfcc/train mfcc
Note: <log-dir> defaults to <data-dir>/log, and <cmvn-dir> defaults to <data-dir>/data
Options:
--fake gives you fake cmvn stats that do no normalization.
--two-channel is for two-channel telephone data, there must be no segments
file and reco2file_and_channel must be present. It will take
only frames that are louder than the other channel.
--fake-dims <n1:n2> Generate stats that won't cause normalization for these
dimensions (e.g. 13:14:15)
! compute-cmvn-stats --spk2utt=ark:$data/spk2utt scp:$data/feats.scp ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp \
2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats. See $logdir/cmvn_$name.log" && exit 1;
The excerpt above is the core of the script: the compute-cmvn-stats command. Its usage and options are as follows:
compute-cmvn-stats
Compute cepstral mean and variance normalization statistics
If wspecifier provided: per-utterance by default, or per-speaker if
spk2utt option provided; if wxfilename: global
Usage: compute-cmvn-stats [options] <feats-rspecifier> (<stats-wspecifier>|<stats-wxfilename>)
e.g.: compute-cmvn-stats --spk2utt=ark:data/train/spk2utt scp:data/train/feats.scp ark,scp:/foo/bar/cmvn.ark,data/train/cmvn.scp
See also: apply-cmvn, modify-cmvn-stats
Options:
--binary : write in binary mode (applies only to global CMN/CVN) (bool, default = true)
--spk2utt : rspecifier for speaker to utterance-list map (string, default = "")
--weights : rspecifier for a vector of floats for each utterance, that's a per-frame weight. (string, default = "")
Standard options:
--config : Configuration file to read (this option may be repeated) (string, default = "")
--help : Print out usage message (bool, default = false)
--print-args : Print the command line arguments (to stderr) (bool, default = true)
--verbose : Verbose level (higher->more logging) (int, default = 0)
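Again there are several options, but only two required arguments: the input feats-rspecifier and the output stats-wspecifier (or a stats-wxfilename for global statistics). The input is the output of the previous MFCC-extraction step, i.e. the raw_mfcc feature files (listed in feats.scp).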
compute-cmvn-stats.cc
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
using kaldi::int32;
const char *usage =
"Compute cepstral mean and variance normalization statistics\n"
"If wspecifier provided: per-utterance by default, or per-speaker if\n"
"spk2utt option provided; if wxfilename: global\n"
"Usage: compute-cmvn-stats [options] <feats-rspecifier> (<stats-wspecifier>|<stats-wxfilename>)\n"
"e.g.: compute-cmvn-stats --spk2utt=ark:data/train/spk2utt"
" scp:data/train/feats.scp ark,scp:/foo/bar/cmvn.ark,data/train/cmvn.scp\n"
"See also: apply-cmvn, modify-cmvn-stats\n";
ParseOptions po(usage);
std::string spk2utt_rspecifier, weights_rspecifier;
bool binary = true;
po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to utterance-list map");
po.Register("binary", &binary, "write in binary mode (applies only to global CMN/CVN)");
po.Register("weights", &weights_rspecifier, "rspecifier for a vector of floats "
"for each utterance, that's a per-frame weight.");
po.Read(argc, argv);
if (po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
int32 num_done = 0, num_err = 0;
std::string rspecifier = po.GetArg(1);
std::string wspecifier_or_wxfilename = po.GetArg(2);
RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier);
if (ClassifyWspecifier(wspecifier_or_wxfilename, NULL, NULL, NULL)
!= kNoWspecifier) { // writing to a Table: per-speaker or per-utt CMN/CVN.
std::string wspecifier = wspecifier_or_wxfilename;
DoubleMatrixWriter writer(wspecifier);
if (spk2utt_rspecifier != "") {
SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
RandomAccessBaseFloatMatrixReader feat_reader(rspecifier);
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
std::string spk = spk2utt_reader.Key();
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
bool is_init = false;
Matrix<double> stats;
for (size_t i = 0; i < uttlist.size(); i++) {
std::string utt = uttlist[i];
if (!feat_reader.HasKey(utt)) {
KALDI_WARN << "Did not find features for utterance " << utt;
num_err++;
continue;
}
const Matrix<BaseFloat> &feats = feat_reader.Value(utt);
if (!is_init) {
InitCmvnStats(feats.NumCols(), &stats);
is_init = true;
}
if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
num_err++;
} else {
num_done++;
}
}
if (stats.NumRows() == 0) {
KALDI_WARN << "No stats accumulated for speaker " << spk;
} else {
writer.Write(spk, stats);
}
}
} else { // per-utterance normalization
SequentialBaseFloatMatrixReader feat_reader(rspecifier);
for (; !feat_reader.Done(); feat_reader.Next()) {
std::string utt = feat_reader.Key();
Matrix<double> stats;
const Matrix<BaseFloat> &feats = feat_reader.Value();
InitCmvnStats(feats.NumCols(), &stats);
if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
num_err++;
continue;
}
writer.Write(feat_reader.Key(), stats);
num_done++;
}
}
} else { // accumulate global stats
if (spk2utt_rspecifier != "")
KALDI_ERR << "--spk2utt option not compatible with wxfilename as output "
<< "(did you forget ark:?)";
std::string wxfilename = wspecifier_or_wxfilename;
bool is_init = false;
Matrix<double> stats;
SequentialBaseFloatMatrixReader feat_reader(rspecifier);
for (; !feat_reader.Done(); feat_reader.Next()) {
std::string utt = feat_reader.Key();
const Matrix<BaseFloat> &feats = feat_reader.Value();
if (!is_init) {
InitCmvnStats(feats.NumCols(), &stats);
is_init = true;
}
if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
num_err++;
} else {
num_done++;
}
}
Matrix<float> stats_float(stats);
WriteKaldiObject(stats_float, wxfilename, binary);
KALDI_LOG << "Wrote global CMVN stats to "
<< PrintableWxfilename(wxfilename);
}
KALDI_LOG << "Done accumulating CMVN stats for " << num_done
<< " utterances; " << num_err << " had errors.";
return (num_done != 0 ? 0 : 1);
} catch(const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
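In the --spk2utt branch, every utterance in each speaker's list is passed to AccCmvnStatsWrapper, which adds it to that speaker's stats. In the per-utterance and global branches, the stats are accumulated per utterance or over all utterances instead.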
AccCmvnStatsWrapper
bool AccCmvnStatsWrapper(const std::string &utt,
const MatrixBase<BaseFloat> &feats,
RandomAccessBaseFloatVectorReader *weights_reader,
Matrix<double> *cmvn_stats) {
if (!weights_reader->IsOpen()) {
AccCmvnStats(feats, NULL, cmvn_stats);
return true;
} else {
if (!weights_reader->HasKey(utt)) {
KALDI_WARN << "No weights available for utterance " << utt;
return false;
}
const Vector<BaseFloat> &weights = weights_reader->Value(utt);
if (weights.Dim() != feats.NumRows()) {
KALDI_WARN << "Weights for utterance " << utt << " have wrong dimension "
<< weights.Dim() << " vs. " << feats.NumRows();
return false;
}
AccCmvnStats(feats, &weights, cmvn_stats);
return true;
}
}
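AccCmvnStatsWrapper mainly handles special cases: no --weights given, a missing weights entry, or weights whose length does not match the number of frames. The real work is done by AccCmvnStats.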
AccCmvnStats
void AccCmvnStats(const VectorBase<BaseFloat> &feats, BaseFloat weight, MatrixBase<double> *stats) {
int32 dim = feats.Dim();
KALDI_ASSERT(stats != NULL);
KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() == dim + 1);
// Remove these __restrict__ modifiers if they cause compilation problems.
// It's just an optimization.
double *__restrict__ mean_ptr = stats->RowData(0),
*__restrict__ var_ptr = stats->RowData(1),
*__restrict__ count_ptr = mean_ptr + dim;
const BaseFloat * __restrict__ feats_ptr = feats.Data();
*count_ptr += weight;
// Careful-- if we change the format of the matrix, the "mean_ptr < count_ptr"
// statement below might become wrong.
for (; mean_ptr < count_ptr; mean_ptr++, var_ptr++, feats_ptr++) {
*mean_ptr += *feats_ptr * weight;
*var_ptr += *feats_ptr * *feats_ptr * weight;
}
}
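The version shown processes a single frame (one feature vector). The wrapper calls the matrix overload AccCmvnStats(feats, weights, stats), which invokes it once per frame with weight 1.0, or with that frame's weight when --weights is given. For each feature dimension, the first row of stats accumulates the weighted sum over frames, and its last element holds the (weighted) frame count. The second row accumulates the weighted sum of squares of each dimension.
Finally, we can run a simple experiment with the compute-mfcc-feats and compute-cmvn-stats programs (compute-mfcc-feats.cc and compute-cmvn-stats.cc).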
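The layout of the stats matrix, and how such statistics are typically turned into normalized features, can be sketched as follows. CmvnStats is a hypothetical struct. AccFrame mirrors the accumulation above. Apply is our own illustration of mean (and optionally variance) normalization in the spirit of apply-cmvn, not Kaldi's implementation. The variance floor of 1e-20 is an arbitrary choice for the sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical 2 x (dim+1) stats holder, stored as two rows.
struct CmvnStats {
  explicit CmvnStats(int feat_dim)
      : dim(feat_dim), row0(feat_dim + 1, 0.0), row1(feat_dim + 1, 0.0) {}
  int dim;
  std::vector<double> row0;  // sum of x per dimension; row0[dim] = total weight
  std::vector<double> row1;  // sum of x^2 per dimension; row1[dim] stays 0

  // Same accumulation as AccCmvnStats(): one frame with a per-frame weight.
  void AccFrame(const std::vector<float> &frame, double weight = 1.0) {
    assert(static_cast<int>(frame.size()) == dim);
    row0[dim] += weight;
    for (int d = 0; d < dim; d++) {
      row0[d] += frame[d] * weight;
      row1[d] += static_cast<double>(frame[d]) * frame[d] * weight;
    }
  }

  // Normalizes a frame in place: subtract the mean and, if norm_vars,
  // divide by the standard deviation (variance floored to avoid 1/0).
  void Apply(std::vector<float> *frame, bool norm_vars) const {
    const double count = row0[dim];
    assert(count > 0.0);
    for (int d = 0; d < dim; d++) {
      const double mean = row0[d] / count;
      double x = (*frame)[d] - mean;
      if (norm_vars) {
        const double var = std::max(row1[d] / count - mean * mean, 1e-20);
        x /= std::sqrt(var);
      }
      (*frame)[d] = static_cast<float>(x);
    }
  }
};
```

Take a one-dimensional feature with frames 1, 2 and 3. The stats rows are [6, 3] and [14, 0], which give mean 2 and variance 14/3 - 4 = 2/3.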