环境配置
https://github.com/wenet-e2e/wenet
git clone https://github.com/wenet-e2e/wenet.git # 克隆源码
我们提供了example/aishell/s0/run.sh
关于 aishell-1 数据的配方
配方很简单,我们建议您手动逐个运行每个阶段并检查结果以了解整个过程。
cd example/aishell/s0
bash run.sh --stage -1 --stop-stage -1
bash run.sh --stage 0 --stop-stage 0
bash run.sh --stage 1 --stop-stage 1
bash run.sh --stage 2 --stop-stage 2
bash run.sh --stage 3 --stop-stage 3
bash run.sh --stage 4 --stop-stage 4
bash run.sh --stage 5 --stop-stage 5
bash run.sh --stage 6 --stop-stage 6
您也可以只运行整个脚本
bash run.sh --stage -1 --stop-stage 6
此阶段将 aishell-1 数据下载到本地路径$data
。这可能需要几个小时。
如果您已经下载了数据,请更改$data
变量run.sh
并从.--stage 0
# 准备训练数据
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# Data preparation
local/aishell_data_prep.sh ${data}/data_aishell/wav \
${data}/data_aishell/transcript
fi
在这个阶段,local/aishell_data_prep.sh
将原始的 aishell-1 数据组织成两个文件:
wav_id
和wav_path
wav_id
和text_label
wav.scp
BAC009S0002W0122 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
BAC009S0002W0124 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
BAC009S0002W0125 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0125.wav
...
text
BAC009S0002W0122 而对楼市成交抑制作用最大的限购
BAC009S0002W0123 也成为地方政府的眼中钉
BAC009S0002W0124 自六月底呼和浩特市率先宣布取消限购后
BAC009S0002W0125 各地政府便纷纷跟进
...
如果您想使用自定义数据进行训练,只需将数据组织成两个文件wav.scp
和text
,然后从.stage 1
example/aishell/s0
使用原始 wav 作为输入,使用TorchAudio在数据加载器中实时提取特征。所以在这一步中,我们只需将训练 wav.scp 和文本文件复制到raw_wav/train/
目录中。
tools/compute_cmvn_stats.py
用于提取全局 cmvn(倒谱均值和方差归一化)统计信息。这些统计数据将用于标准化声学特征。设置cmvn=false
将跳过此步骤。
# 提取可选 cmvn 特征
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# remove the space between the text labels for Mandarin dataset
for x in train dev test; do
cp data/${x}/text data/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) \
<(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
> data/${x}/text
rm data/${x}/text.org
done
tools/compute_cmvn_stats.py --num_workers 8 --train_config $train_config \
--in_scp data/${train_set}/wav.scp \
--out_cmvn data/$train_set/global_cmvn
fi
# 生成标签令牌字典
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Make a dictionary"
mkdir -p $(dirname $dict)
echo " 0" > ${dict} # 0 is for "blank" in CTC
echo " 1" >> ${dict} # must be 1
tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \
| tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \
awk '{print $0 " " NR+1}' >> ${dict}
num_token=$(cat $dict | wc -l)
echo " $num_token" >> $dict
fi
dict 是标签标记(我们为 Aishell-1 使用字符)和整数索引之间的映射。
一个示例字典如下
0
1
一 2
丁 3
...
龚 4230
龟 4231
4232
表示 CTC 的空白符号。
表示未知标记,任何词汇表外的标记都将映射到其中。
表示用于基于注意力的编码器解码器训练的语音开始和语音结束符号,并且它们共享相同的 id。此阶段生成 WeNet 所需的格式文件data.list
。中的每一行data.list
都是 json 格式,其中包含以下字段。
key
: 话语的关键wav
: 话语的音频文件路径txt
:话语的标准化转录,转录将在训练阶段即时标记为模型单元。这是一个示例data.list
,请参阅生成的训练特征文件data/train/data.list
。
{"key": "BAC009S0002W0122", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0122.wav", "txt": "而对楼市成交抑制作用最大的限购"}
{"key": "BAC009S0002W0123", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0123.wav", "txt": "也成为地方政府的眼中钉"}
{"key": "BAC009S0002W0124", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0124.wav", "txt": "自六月底呼和浩特市率先宣布取消限购后"}
我们还设计了另一种data.list
命名格式,shard
用于大数据训练。如果您想在大数据集(超过 5k)上应用 WeNet,请参阅gigaspeech(10k 小时)或 wenetspeech(10k 小时),了解如何使用shard
样式data.list
。
NN 模型在此步骤中进行训练。
如果对多 GPU 使用 DDP 模式,我们建议使用dist_backend="nccl"
. 如果 NCCL 不起作用,请尝试使用gloo
或使用torch==1.6.0
Set the GPU ids in CUDA_VISIBLE_DEVICES。例如,设置为使用卡 0,1,2,3,6,7。export CUDA_VISIBLE_DEVICES="0,1,2,3,6,7"
如果您的实验在运行几个 epoch 后由于某些原因(例如 GPU 被其他人意外使用并且内存不足)而终止,您可以从检查点模型继续训练。只需找出完成的 epoch exp/your_exp/
,设置 checkpoint=exp/your_exp/$n.pt
并运行. 然后训练将从 $n+1.pt 继续run.sh --stage 4
神经网络结构、优化参数、损失参数和数据集的配置可以在 YAML 格式文件中设置。
在conf/
中,我们提供了几种模型,例如变压器和构象器。见conf/train_conformer.yaml
参考。
培训需要几个小时。实际时间取决于 GPU 卡的数量和类型。在一台 8 卡 2080 Ti 机器中,50 个 epoch 大约需要不到一天的时间。您可以使用 tensorboard 来监控损失。
tensorboard --logdir tensorboard/$your_exp_name/ --port 12598 --bind_all
dir=exp/conformer
cmvn_opts="--cmvn ${dir}/global_cmvn"
train_config=conf/train_conformer.yaml
data_type=raw
dict=data/dict/lang_char.txt
train_set=train
python3 train.py \
--config $train_config \
--data_type $data_type \
--symbol_table $dict \
--train_data data/$train_set/data.list \
--model_dir $dir \
--cv_data data/dev/data.list \
--num_workers 1 \
$cmvn_opts \
--pin_memory
需要文件:
dict词典文件:words.txt
model:final.pt
训练模型用的配置文件:/train.yaml
cmvn文件:在配置文件里面配置路径
待识别语言列表:data.list # 格式{key,wavscp,text}
解码过程
dir=/root/data/aizm/wenet/pre_modle/20210618_u2pp_conformer_exp
data_type=raw
dict=${dir}/words.txt
decode_checkpoint=${dir}/final.pt
decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0
test_dir=$dir/test_attention_rescoring
# 测试的语音内容{key,wavscp,text}
list_name=vad_test
data_list_dir=${list_name}.list
mkdir -p $test_dir
python recognize.py \
--mode "attention_rescoring" \
--config $dir/train.yaml \
--data_type $data_type \
--test_data ${data_list_dir} \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_file $test_dir/text_${list_name} \
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
python tools/compute-wer.py --char=1 --v=1 \
data/test/text $test_dir/text > $test_dir/wer_${list_name}
这个阶段展示了如何将一组 wav 识别为文本。它还展示了如何进行模型平均。
如果${average_checkpoint}
设置为true
,则交叉验证集上的最佳${average_num}
模型将被平均以生成增强模型并用于识别。
识别也称为解码或推理。NN的功能将应用于输入的声学特征序列以输出文本序列。
WeNet 中提供了四种解码方法:
ctc_greedy_search
: encoder + CTC 贪婪搜索ctc_prefix_beam_search
:encoder + CTC 前缀波束搜索attention_rescoring
:在基于注意力的解码器上使用编码器输出从 ctc 前缀波束搜索中重新评估 ctc 候选者。一般来说,attention_rescoring 是最好的方法。有关这些算法的详细信息,请参阅U2 论文。
--beam_size
是一个可调参数,较大的光束尺寸可能会获得更好的结果,但也会导致更高的计算成本。
--batch_size
“ctc_greedy_search”和“attention”解码模式可以大于1,“ctc_prefix_beam_search”和“attention_rescoring”解码模式必须为1。
tools/compute-wer.py
将计算结果的单词(或字符)错误率。如果您在没有任何更改的情况下运行配方,您可能会得到 WER ~= 5%。
wenet/bin/export_jit.py
将使用 Libtorch 导出经过训练的模型。导出的模型文件可轻松用于其他编程语言(如 C++)的推理。