Converting a TensorFlow pretrained model to a PyTorch pretrained model

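While converting ALBERT's TensorFlow pretrained model into a PyTorch pretrained model, I ran into quite a few pitfalls. I finally got it working and hope this helps others.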

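  1. Preparation
    Create an environment that has both torch and tf installed, as follows.
    First, create the environment and activate it:
    conda create -n torchtf_env python=3.7
    conda activate torchtf_env
    Then install torch (pick the cudatoolkit version that matches your machine's CUDA):
    conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge
    Next, install the GPU build of TensorFlow 1.15:
    conda install tensorflow-gpu==1.15
    Finally, install transformers:
    pip install transformers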
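To confirm that all three packages actually import inside the new environment, a small check like the one below can be run. `check_env` is a hypothetical helper written for this post, not part of any library; it simply tries each import and reports the `__version__` it finds (with the commands above you would expect a 1.15.x version for tensorflow). TensorFlow is only needed here to read the original checkpoint.

```python
import importlib

# Packages the conversion script relies on.
REQUIRED = ("torch", "tensorflow", "transformers")


def check_env(packages=REQUIRED):
    """Return {package: version string, or None if the import fails}."""
    found = {}
    for name in packages:
        try:
            module = importlib.import_module(name)
            found[name] = getattr(module, "__version__", "unknown")
        except Exception:  # ImportError, or a broken install raising something else
            found[name] = None
    return found


if __name__ == "__main__":
    for name, version in check_env().items():
        print("{}: {}".format(name, version if version else "NOT INSTALLED"))
```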

2. Download the TensorFlow-pretrained ALBERT model from GitHub, then convert it to PyTorch with the script below.
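Before running the conversion script, note that a TensorFlow checkpoint such as `model.ckpt-best` is not a single file but a prefix shared by several files: `model.ckpt-best.index`, one or more `model.ckpt-best.data-XXXXX-of-YYYYY` shards, and usually `model.ckpt-best.meta`. The script's `--tf_checkpoint_path` should point at that prefix, as its default `albert_base_en/model.ckpt-best` does. The optional pre-check below is a sketch: `resolve_ckpt_prefix` is a hypothetical helper that assumes the V2 checkpoint layout and only inspects file names; it does not open the checkpoint.

```python
import glob
import os

# Single-file suffixes TensorFlow writes next to a checkpoint prefix.
_SUFFIXES = (".index", ".meta")


def resolve_ckpt_prefix(path):
    """Turn a checkpoint file name or prefix into the bare prefix.

    Raises FileNotFoundError if the .index file or the data shards are missing.
    """
    for suffix in _SUFFIXES:
        if path.endswith(suffix):
            path = path[: -len(suffix)]
            break
    else:
        # Strip a '.data-00000-of-00001'-style shard suffix if present.
        head, sep, _ = path.rpartition(".data-")
        if sep:
            path = head
    if not os.path.isfile(path + ".index"):
        raise FileNotFoundError(path + ".index")
    if not glob.glob(glob.escape(path) + ".data-*"):
        raise FileNotFoundError(path + ".data-*")
    return path
```

For example, `resolve_ckpt_prefix("albert_base_en/model.ckpt-best.index")` returns `"albert_base_en/model.ckpt-best"` when the files are present, and that value can be passed as `--tf_checkpoint_path`.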

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
Created on 19/03/2021 20:22 
@Author: lixj
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import torch
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
import logging
logging.basicConfig(level=logging.INFO)

def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Read the model architecture from the ALBERT config JSON
    config = AlbertConfig.from_json_file(bert_config_file)
    # print("Building PyTorch model from configuration: {}".format(str(config)))
    # Initialise a randomly weighted PyTorch model with the same architecture
    model = AlbertForPreTraining(config)
    # Overwrite its weights with the values stored in the TF checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Input/output paths; the defaults assume everything sits in albert_base_en/
    parser.add_argument(
        "--tf_checkpoint_path", default='albert_base_en/model.ckpt-best', type=str, help="Path (prefix) of the TensorFlow checkpoint."
    )
    parser.add_argument(
        "--bert_config_file",
        default='albert_base_en/albert_config.json',
        type=str,
        help="The config JSON file of the pre-trained ALBERT model. "
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default='albert_base_en/pytorch_model.bin', type=str, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
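After the script finishes, one way to sanity-check the output is to load `pytorch_model.bin` back into `AlbertForPreTraining` with `strict=False` and look at which keys did not match. The sketch below assumes the default paths used above; `verify_conversion` is a hypothetical helper, and torch/transformers are imported lazily so that a wrong path is reported before the heavy imports. Both lists should normally be empty when the same transformers version is used for converting and loading; this confirms the structure of the saved file, not the numerical correctness of the weights.

```python
import os


def verify_conversion(config_file, pytorch_dump_path):
    """Reload converted weights; return (missing_keys, unexpected_keys)."""
    # Fail fast on wrong paths before importing torch/transformers.
    for path in (config_file, pytorch_dump_path):
        if not os.path.isfile(path):
            raise FileNotFoundError(path)

    import torch
    from transformers import AlbertConfig, AlbertForPreTraining

    config = AlbertConfig.from_json_file(config_file)
    model = AlbertForPreTraining(config)
    state_dict = torch.load(pytorch_dump_path, map_location="cpu")
    # strict=False reports mismatched keys instead of raising an error.
    result = model.load_state_dict(state_dict, strict=False)
    return list(result.missing_keys), list(result.unexpected_keys)


if __name__ == "__main__":
    try:
        missing, unexpected = verify_conversion(
            "albert_base_en/albert_config.json", "albert_base_en/pytorch_model.bin"
        )
    except FileNotFoundError as err:
        print("File not found:", err)
    else:
        print("missing keys:", missing)
        print("unexpected keys:", unexpected)
```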
