探究torchAudio中wav2vec2的源码(一)

由于上一篇博客,我把torchAudio中的wav2vec2样例加上自己理解发了出来,这次我们就来看看torchaudio中,wav2vec2.0的模型是怎么创建的。

博主也在边写博客边看这源码学习,理解得不一定对,有错希望大佬指出~

观前提示:

  1. 由于直接打开torchaudio项目,没导环境,因此会有红线,并不是代码问题。
  2. 上一篇博客的地址:https://blog.csdn.net/weixin_43142450/article/details/123831419?spm=1001.2014.3001.5502

pipelines

首先我们看回这个样例的一行和创建模型很相关的代码。

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

我在理解中,这句话就等于创建一个wav2vec2语音识别的模型工厂了。我们看看里面是什么。

探究torchAudio中wav2vec2的源码(一)_第1张图片

可以看到引用的bundle在pipelines初始化文件中。

这个文件主要引用了文件夹_wav2vec2中的文件impl.py的数据类。

探究torchAudio中wav2vec2的源码(一)_第2张图片

可以看出WAV2VEC2_ASR_960H 等于创建了一个Wav2Vec2ASRBundle数据类。这个数据类里的数据,看着就是特征提取和分类的模型构造参数。然后还有最终的字母标签获取_labels和采样率设置_sample_rate

我们继续递归,看看这个Wav2Vec2ASRBundle数据类的结构:

探究torchAudio中wav2vec2的源码(一)_第3张图片

就这点东西。Wav2Vec2ASRBundle数据类继承了Wav2Vec2Bundle数据类。然后自己新增了标签属性_labels,和一个不知什么东西的位置_remove_aux_axis。里面还定义了get_labels方法和_get_state_dict方法

然后我们再看看被Wav2Vec2ASRBundle类继承的Wav2Vec2Bundle类是什么:

探究torchAudio中wav2vec2的源码(一)_第4张图片

里面有模型路径_path、模型参数_param、采样率_sample_rate,以及获取采样率的方法sample_rate、获取状态字典方法_get_state_dict、获取模型方法get_model。

(博主表示不知道状态字典的作用是什么,后面看到了再进行修改,或者有大佬指点也可以,感谢)

在我写了样例博客的单元块三中有一行代码,如是写道:

model = bundle.get_model().to(device)

可以看出,他就是调用了Wav2Vec2Bundle类中的get_model方法。通过get_model方法能直接建立好我们的wav2vec2模型。因此,我们仔细看看get_model方法的代码。

探究torchAudio中wav2vec2的源码(一)_第5张图片

我们点进wav2vec2_model方法看看

models

它跳到了models文件夹中的model.py文件。

探究torchAudio中wav2vec2的源码(一)_第6张图片

我们发现传参就是我们的模型结构参数。

探究torchAudio中wav2vec2的源码(一)_第7张图片

然后做了特征提取工作feature_extractor和预训练transformer模型创建工作encoder,以及下面的线性转换aux。最后返回一个Wav2Vec2Models回去。

而aux_num_out代表什么,博主不知道,博主表示压根不知道aux代表什么东西。因此希望各位好学者推测推测或者大牛给个答案?

今天就先水到这了,明天再接着看。

你可能感兴趣的:(语音识别,python,语音识别,语言模型)