BERT: pytorch与tf 模型互转

背景

BERT是google出的, 理所当然使用tf框架, 但是目前很多项目是使用pytorch框架的, 所以就需要在两个框架之间转换bert模型.

方法

pytorch to tf

主要使用huggingface的转换脚本.但是有几个地方需要修改:

  1. 修改包导入:
    from transformers import BertModelfrom modeling_bert import BertModel
  2. 修改L102, load模型的参数:
    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        state_dict=torch.load(args.pytorch_model_path),
        cache_dir=args.cache_dir,
    )
    
    为:
    model = BertModel.from_pretrained(
        state_dict=torch.load(args.pytorch_model_path)
    )
    

最后运行脚本:
python convert_bert_pytorch_checkpoint_to_original_tf.py --model_name --pytorch_model_path --tf_cache_dir
其中model_name随便指定一个即可, 没有影响, 不过需要在当前目录下新建model_name目录, 然后把pytorch模型对应的config.json放到该目录下.其他两个参数就是对应的模型了, 没什么好解释的.

tf to pytorch

也是使用huggingface的转换脚本.
如果是bert base的转换, 改动与上述pytorch to tf类似.
如果是finetune好的分类模型, 就需要修改转换代码了:

  1. modeling_bert.py:117:
for name, array in zip(names, arrays):
    name = name.split("/")

修改为:

for name, array in zip(names, arrays):
    if name in ['output_weights', 'output_bias']:
        name = 'classifier/' + name
    name = name.split("/")
  1. 将转换脚本convert_bert_original_tf_checkpoint_to_pytorch.py:33:
model = BertForPreTraining(config)

修改为

config.num_labels = 2
model = BertForSequenceClassification(config)

最后按照提示进行参数调用即可.

pytorch与tf共存

由于转换过程中要这两个框架,而且都要用到CUDA, tensorflow还得是1.xx版本.
难点: 首先tensorflow1.xx版本不支持cuda10.1, 所以只能用cuda10.0, pytorch绝大多数版本不支持cuda10.0
所以以下是我的解决方案:
方法一(cuda10.0):

  1. 首先新建虚拟环境, 安装tensorflow:
    pip install tensorflow-gpu==1.14
  2. 安装pytorch:
    只有pytorch 1.4版本是支持cuda10.0的, 安装:
    pip install torch==1.4.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
  3. 测试
import torch
import tensorflow as tf

print("Pytorch CUDA: {}".format(torch.cuda.is_available())
print("Tensorflow CUDA: {}".format(tf.test.is_gpu_available())

方法二(cuda10.1):
该方法就是使用conda安装cudatoolkit和cudnn, 然后安装tf 1.xx版本也是可以使用cuda的.

  1. 安装cudatoolkit和cudnn
    conda install cudatoolkit=10.1
    conda install cudnn=7.6.4
  2. 正常安装tf 1.xx和pytorch

你可能感兴趣的:(BERT: pytorch与tf 模型互转)