方法1:请自行通过Transformers提供的转换脚本进行转换。
将TensorFlow版本模型转为Pytorch版本:https://huggingface.co/transformers/converting_tensorflow_models.html
具体每个模型如何转为相应的Pytorch版本:https://github.com/huggingface/transformers/tree/master/src/transformers/models里面有各个模型的covert_modelname_original_tf_checkpoint_to_pytorch.py
文件,如bert
的https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py
python convert_xlnet_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path TF_CHECKPOINT_PATH=/xlnet/chinese_xlnet_base_L-12_H-768_A-12/xlnet_model.ckpt --xlnet_config_file XLNET_CONFIG_FILE=/xlnet/chinese_xlnet_base_L-12_H-768_A-12/xlnet_config.json --pytorch_dump_folder_path PYTORCH_DUMP_FOLDE=/xlnet/chinese_xlnet_base_L-12_H-768_A-12/xlnet_model.bin
第一种,需要连网,参照https://www.bilibili.com/read/cv8231417/
from transformers import *
model_name = 'hfl/chinese-xlnet-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
运行后系统会自动下载相关的模型文件并存放在电脑中,模型保存的路径在 ~/.cache/torch/transformers/
目录下 。这个方法下载较快。
第二种,对于无法连网,可以提前下载保存。通过huggingface官网直接下载PyTorch版权重:https://huggingface.co/hfl
具体操作: