使用Transformers将TF模型转化成PyTorch模型

场景

使用tensorflow将TF模型转化成PyTorch模型

步骤

获取如下三个文件:

  • src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py:这个是将tensorflow2.x Bert模型转化成PyTorch可用的模型。
  • src/transformers/models/bert/modeling_bert.py:Bert模型使用例子。
  • BERT-Base, Multilingual Cased (New, recommended):基于Bert的多语言预训练模型。

这里假设已经安装过PyTorch了。
开始转化TF2模型位PyTorch模型:

# 安装依赖
pip3 install tensorflow transformers
export BERT_BASE_DIR=~/Downloads/nlp_bert/multi_cased_L-12_H-768_A-12
transformers-cli convert --model_type bert \
  --tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \
  --config $BERT_BASE_DIR/bert_config.json \
  --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin

这里的pytorch_model.bin就是TF2的已经训练好的模型转化过来的PyTorch模型。

参考:

  • Google预训练模型google-research/bert
  • Converting Tensorflow Checkpoints

你可能感兴趣的:(tensorflow,pytorch,深度学习)