TensorFlow版本的预训练模型与Pytorch版本的预训练模型的转换

文章目录

  • TensorFlow版本的预训练模型与Pytorch版本的预训练模型的转换
    • 1. 方法1:Transformers的转换脚本
    • 2. 方法2:直接下载
      • 2.1 第一种:连网下载
      • 2.2 第二种:无法连网

TensorFlow版本的预训练模型与Pytorch版本的预训练模型的转换

1. 方法1:Transformers的转换脚本

方法1:请自行通过Transformers提供的转换脚本进行转换。

将TensorFlow版本模型转为Pytorch版本:https://huggingface.co/transformers/converting_tensorflow_models.html

具体每个模型如何转为相应的Pytorch版本:https://github.com/huggingface/transformers/tree/master/src/transformers/models里面有各个模型的covert_modelname_original_tf_checkpoint_to_pytorch.py文件,如bert的https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py

python convert_xlnet_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path TF_CHECKPOINT_PATH=/xlnet/chinese_xlnet_base_L-12_H-768_A-12/xlnet_model.ckpt --xlnet_config_file XLNET_CONFIG_FILE=/xlnet/chinese_xlnet_base_L-12_H-768_A-12/xlnet_config.json --pytorch_dump_folder_path PYTORCH_DUMP_FOLDE=/xlnet/chinese_xlnet_base_L-12_H-768_A-12/xlnet_model.bin

2. 方法2:直接下载

2.1 第一种:连网下载

第一种,需要连网,参照https://www.bilibili.com/read/cv8231417/

from transformers import *

model_name = 'hfl/chinese-xlnet-base'

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModel.from_pretrained(model_name) 

运行后系统会自动下载相关的模型文件并存放在电脑中,模型保存的路径在 ~/.cache/torch/transformers/ 目录下 。这个方法下载较快。

TensorFlow版本的预训练模型与Pytorch版本的预训练模型的转换_第1张图片

2.2 第二种:无法连网

第二种,对于无法连网,可以提前下载保存。通过huggingface官网直接下载PyTorch版权重:https://huggingface.co/hfl

具体操作:

  1. 点击任意需要下载的model →
  2. 进去Model card页面,切换Tab点击"files and versionsl" →
  3. 下载bin和json文件。

TensorFlow版本的预训练模型与Pytorch版本的预训练模型的转换_第2张图片

你可能感兴趣的:(NLP,AI,深度学习,pytorch,tensorflow,深度学习)