谷歌原版bert模型tf转pytorch命令

从TensorFlow 检查点转换

提供了一个命令行界面,用于将原始的Bert/GPT/GPT-2/Transformer-XL/XLNet/XLM检查点转换为可以使用库的from_pretrained方法加载的模型。
从 2.3.0 开始,转换脚本现在是任何可用的transformers CLI(transformer-cli)的一部分
transformers安装版本 >= 2.3.0
下面的文档反映了转换器-cli 转换命令格式。

BERT

可以使用convert_bert_original_tf_checkpoint_to_pytorch.py脚本在 PyTorch 保存文档中转换 BERT 的任何 TensorFlow 检查点(特别是 Google 发布的预训练模型)。

此 CLI 将 TensorFlow 检查点(三个以 bert_model.ckpt 开头的文档)和关联的配置文档 (bert_config.json) 作为输入,并为此配置创建一个 PyTorch 模型,从 PyTorch 模型中的 TensorFlow 检查点加载权重,并将生成的模型保存在可以使用 from_pretrained() 导入的标准 PyTorch 保存文档中(请参阅快速浏览中的示例 quicktour, run_glue.py)。

只需运行一次此转换脚本即可获得 PyTorch 模型。然后,可以忽略 TensorFlow 检查点(以 bert_model.ckpt 开头的三个文档),但请务必保留配置文档 (\ bert_config.json) 和词汇文档 (vocab.txt),因为 PyTorch 模型也需要这些文档。

要运行这个特定的转换脚本,你需要安装TensorFlow和PyTorch(pip install tensorflow)。存储库的其余部分只需要 PyTorch。

以下是预训练的BERT-Base Uncased模型的转换过程示例:

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

transformers-cli convert --model_type bert \
  --tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \
  --config $BERT_BASE_DIR/bert_config.json \
  --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin

可以在此处下载 Google 预先训练的转换模型。

你可能感兴趣的:(pytorch,bert,tensorflow)