Fairseq的wav2vec2的踩坑之旅4:如何手动将一个Fairseq的wav2vec2模型转换为transformers的模型

摘要:

本文尝试将用中文拼音预训练的Fairseq的wav2vec2模型转换为transformers模型(以下简写trms),因为汉语拼音的label数量与英文不同,所以本文需要进行模型转换函数的修改。

自己预训练和finetune的模型没有稳定输出,但是应该是label转换的问题
本文可能对“复现党”有一定的参考价值

文章目录

  • 摘要:
  • 1.分析transofrmers模型的结构
  • 2.使用transformers的工具进行导入
    • 2.1 导入工具参数说明
    • 2.2 创建对应的tokenizer需要的文件
  • 3.测试转换后的trms模型的效果
  • 4.使用Transformers对汉语拼音进行finetune

1.分析transofrmers模型的结构

huggingface下载的模型默认保存在~/.cache/huggingface下面,如果需要离线使用,则需要将其保存到一个常见可见的目录,方便手动管理。

在模型目录下一般包括如下的文件:

  • config.json 模型配置文件,项目配置文件
  • vocab.json 编解码器的字典文件,json格式,字典:key是label,值是id
  • pytorch_model.bin trms转换后的pytorch模型文件
  • special_tokens_map.json 编解码器的特数据字符
  • tokenizer_config.json 编解码器的配置文件

使用fairseq.checkpoint_utils.load_model_ensemble_and_task([fname])尝试导入pytorch_model.bin,报错,分析是从huggingface下载的模型是没有fairseq的task/args/cfg等信息。

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))
        
    #分析类型state是

Tips: 这里有一个方便的下载各个模型的小工具,下载模型到具体目录保存。

from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
import argparse
import os

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="facebook/wav2vec2-base-100h", help="pretrained model name")
    parser.add_argument("--save_dir", type=str, default="openai-gpt", help="pretrained model name")
    args = parser.parse_args()
    print(args)    
    #save model
    save_dir = os.path.expanduser(args.save_dir)    
    # load model and tokenizer
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(args.model_name) 
    tokenizer.save_pretrained(save_dir)
    model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
    model.save_pretrained(save_dir)
    
if __name__=="__main__":
    """
    $prj=~/Documents/projects/transformers/facebook/wav2vec2-base-100h
    python ~/Documents/workspace/fairseq2trms/downloadModel.py --model_name="facebook/wav2vec2-base-100h"  --save_dir=$prj
    """
    main()

2.使用transformers的工具进行导入

**说明:**为什么不直接Fairseq,而是要用transformers呢?

  1. fairseq存在比较严重的过渡封装问题,接口复杂,omgaconf传参工具不容易迁移,不适合作生产环境部署
  2. fairseq做评估和ASR应用需要flashlight,由于防火墙的存在,基本上是无法按照官方教程安装的(vcpkg和编译都不容易)
  3. trms的接口比较直接明确,工具链比较简单

2.1 导入工具参数说明

trms本身提供了从fairseq导入wav2vec2模型的工具:transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py

使用如下脚本可以从自己训练的模型转换trms用的模型:

python -m  transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch  --pytorch_dump_folder_path "~/Documents/projects/transformers/bostenai/960h-zh_CN" --checkpoint_path "~/Documents/projects/Fairseq/ModelSerial/CTC-softlink-slr18_bst-0310/ctc_d1/outputs/checkpoints2/checkpoint_best.pt" --dict_path "~/Documents/projects/transformers/bostenai/960h-zh_CN/dict.ltr.txt" 
#具体参数含义如下
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="transformers模型导出到该目录.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="需要转换的fairseq的wav2vec2模型文件")
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model,但是从代码看,这就是填充data字段防止fairseq无法加载的占为参数,随便给一个就好,也就是为什么不用fairseq的原因,代码质量太差!")
#在内部如果指定了config将直接使用,因此不能指定config的path
#dict_path 貌似是有歧义的参数名,只是使用用来替换data字段的,不传会错误,传递了也没啥用
parser.add_argument("--config_path", default=None, type=str, help="转换后的配置文件,因为是转换所以留空不传,传了出错!")
parser.add_argument(
    "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
)

Error1: 缩略~导致的问题

trms并没有对路径参数调用os.path.expanduser(),导致无法推倒含的路径。但是处于个人隐私考虑,全文还是使用代替传递参数中的绝对路径,请自行替换,下同。

requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/~/Documents/projects/transformers/bostenai/960h-zh_CN//resolve/main/config.json
During handling of the above exception, another exception occurred:
- '~/Documents/projects/transformers/bostenai/960h-zh_CN/' is a correct model identifier listed on 'https://huggingface.co/models'
- or '~/Documents/projects/transformers/bostenai/960h-zh_CN/' is the correct path to a directory containing a config.json file

Error2: config.json

如果指定了config.json参数,而该文件不能自行生成,将导致如下错误,直接留空不传

OSError: file /home/xxx/Documents/projects/transformers/bostenai/960h-zh_CN/config.json not found
During handling of the above exception, another exception occurred:

Error3: Exception “omegaconf.errors.ValidationError”

这是Fairseq的Bug和功能缺陷,Omgaconf实在太复杂,当不传递–dict_path时导致:data等参数无法填充

Exception "omegaconf.errors.ValidationError"
Non optional field cannot be assigned None
	full_key: data
	reference_type=Optional[AudioPretrainingConfig]
	object_type=AudioPretrainingConfig
File: ~/.conda/envs/lSrv39/lib/python3.9/site-packages/omegaconf/_utils.py, Line: 610

#在内部如果指定了config将直接使用,因此不能指定config的path
#dict_path 貌似是有歧义的参数名,只是使用用来替换data字段的,不传会错误,传递了也没啥用

Error4:finetune层的维度不一致的问题

注意:因为我们是zh_CN识别的模型,而官方是英文,一个vocab是26而中文是222,因次原始的加载函数需要魔改

为了能够正确加载模型,我们需要将trms提供的加载函数略作修改

#transformers/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
#对下面的函数进行修改,首先 加载fairseq模型,读取ctc层的配置,最后再创建trms模型进行复制

@torch.no_grad()
def convert_wav2vec2_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    #首先加载需要转换的模型
    if is_finetuned:
        model, saved_cfg, task= fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path], arg_overrides={"data": dict_path}
        )
        #注意这里的dict_path是随便赋值的,因为没啥实际作用,仅仅是让fairseq不要出错。在fairseq加载时有大量莫名其妙的参数检查和以来,omgaconf的实现质量十分的差
    else:
        model, saved_cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
    model = model[0].eval()
    
    #创建默认的值
    if  config_path is not None:
        config = Wav2Vec2Config.from_pretrained(config_path)
    else:
        config = Wav2Vec2Config()    
    
    if is_finetuned:
        #修改config的值,这里直接使用了最后一层的值,因为其他地方得不到
        config.vocab_size = list(model.modules())[-1].out_features
        hf_wav2vec = Wav2Vec2ForCTC(config)
    else:
        hf_wav2vec = Wav2Vec2Model(config)

    recursively_load_weights(model, hf_wav2vec, is_finetuned)
    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)

2.2 创建对应的tokenizer需要的文件

我认为fairseq训练的时候使用了默认的配置,那么special_tokens_map.json tokenizer_config.json preprocessor_config.json都可以复制,只需要修改vocab.json

创建vocab.json需要从fairseq的dict.ltr.txt转换,转换脚本如下:

# coding=utf-8
"""
创建汉语拼音使用的字典
"""
import argparse
import os
import json

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dict", type=str, default="dict.ltr.txt", help="finetune时的字典文件")
    parser.add_argument("--vocab", type=str, default="vocab.json", help="输出的字典json")
    args = parser.parse_args()
    print(args)
    #vocab 系统自带和默认的几个key
    oVocab ={"": 0, "": 1, "": 2, "": 3}    
    #加载fairseq的dict.ltr.txt文件
    dictF  = os.path.expanduser(args.dict)
    vocabF = os.path.expanduser(args.vocab)
    with open(dictF, 'r') as df:
        text = df.readlines()
        tLen = len(oVocab)
        #转换为trms的字典
        for it in text:
            key, p = it.strip().split()
            oVocab[key] = tLen
            tLen = tLen+1
    #写出到vocab.json
    with open(vocabF, 'w+', encoding="utf-8") as cf:
        json.dump(oVocab, cf, sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False )

    
if __name__=="__main__":
    """
    $prj=~/Documents/projects/transformers/bostenai/960h-zh_CN
    python ~/Documents/workspace/fairseq2trms/create_vocab.py --dict=$prj/dict.ltr.txt  --vocab=$prj/vocab.json
    """
    main()

脚本将创建一个自定义的汉语拼音的vocab文件

到此模型转换的过程基本完成

3.测试转换后的trms模型的效果

使用官方模型进行测试,这里使用local_files_only=True来加载我们下载或者本地封装的模型

#模拟Fairseq官方提供的例子写的一个Demo
from transformers import  Wav2Vec2ForCTC, Wav2Vec2Processor
import argparse
import os
#from pypinyin import  lazy_pinyin, Style
import soundfile as sf
import torch

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="~/Documents/Projects/transformers/bostenai/100h-zh_CN", help="pretrained model path")
    parser.add_argument("--local_files_only", action="store_true", help="是否使用本地模式")
    parser.add_argument("wav", type=str,  help="待分析的音频文件")
    args = parser.parse_args()
    print(args)    
    #save model
    model_path = os.path.expanduser(args.model_path)    
    # load model and tokenizer
    tokenizer = Wav2Vec2Processor.from_pretrained(model_path, local_files_only=args.local_files_only)
    model = Wav2Vec2ForCTC.from_pretrained(model_path, local_files_only=args.local_files_only)
    
    #read in sound file    
    #tStr ="我爱北京天安门"
    #tList = lazy_pinyin(tStr, style=Style.TONE )     
    #print("|".join(tList))
    #tAudio = "/home/linger/Downloads/temp/1.wav"    
    tAudio = os.path.expanduser(args.wav)
    audio_input, sr = sf.read(tAudio)    
    # transcribe
    input_values = tokenizer(audio_input, return_tensors="pt", sampling_rate=sr).input_values
    logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)[0]
    print(transcription)
    
if __name__=="__main__":
    """
    python /home/linger/Documents/workspace/fairseq2trms/test_trms.py  --model_path ~/Documents/Projects/transformers/bostenai/100h-zh_CN --local_files_only /home/linger/Downloads/temp/1.wav
    """
    main()

如果测试官方模型可能会得到这样的结果:

$ python /home/linger/Documents/workspace/fairseq2trms/test_trms.py  --model_path facebook/wav2vec2-base-960h  --local_files_only /home/linger/Downloads/temp/1.wav
Namespace(model_path='facebook/wav2vec2-base-960h', local_files_only=True, wav='/home/linger/Downloads/temp/1.wav')
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
WA I BA GING KAN IMON

看到英语直接拼音汉语是不是很好玩. 哈哈:WA I BA GING KAN IMON

4.使用Transformers对汉语拼音进行finetune

好尴尬!! 前期预训练的汉语拼音的wav2vec2没有没有输出我想要的结果:

$    python /home/linger/Documents/workspace/fairseq2trms/test_trms.py  --model_path ~/Documents/Projects/transformers/bostenai/100h-zh_CN --local_files_only /home/linger/Downloads/temp/1.wav
Namespace(model_path='/home/linger/Documents/Projects/transformers/bostenai/100h-zh_CN', local_files_only=True, wav='/home/linger/Downloads/temp/1.wav')

w<unk><unk><unk><unk><unk><unk>

那么尝试在我的小本上试试finetune吧

未完待续

你可能感兴趣的:(Fairseq,python,pytorch,深度学习,ubuntu)