Source: contributed article  Author: Mr.剑豪
Editor: 学姐
Today's AI models rarely stray from two frameworks, PyTorch and TensorFlow. The former is very friendly for study and research and easy to pick up; the latter is better suited to production development.
When running experiments you often hit a mismatch between code and model. If the code needs a PyTorch model, that is easy: a quick search on Hugging Face turns up an open-source checkpoint for almost anything. But if you need a TensorFlow model, resources are much harder to find.
The problem I ran into was this: I had spent a week training an ELECTRA model with the torch framework, only to find when it came time to use it that one or two pieces of code still required TensorFlow. After days of searching I came up empty; the only torch-to-tf converters I could find were for BERT and RoBERTa (https://github.com/percent4/roberta_torch_2_tf), and for anything else I either found nothing or could not get it to work. In the end I adapted the RoBERTa converter into a conversion script for ELECTRA.
In fact ELECTRA, RoBERTa, and BERT are much the same: whether in torch or tf, a checkpoint is just a bag of model parameters, and only the naming and layout of those parameters differ slightly, which is what causes trouble for users on a different framework. You can see this by dumping the parameter names on both sides, as in the sketch below.
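As a rough illustration (the file names here are placeholders, not from this article), the following snippet prints the parameter names stored in a PyTorch checkpoint and in a TF ckpt; they are the same tensors under slightly different names, e.g. LayerNorm.weight on the torch side versus LayerNorm/gamma on the tf side:

import torch
import tensorflow as tf  # TF 1.x

# PyTorch side: a checkpoint is just a dict of parameter name -> tensor.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
for name, tensor in list(state_dict.items())[:10]:
    print(name, tuple(tensor.shape))

# TensorFlow side: a ckpt is likewise a flat list of variable names and shapes.
for name, shape in tf.train.list_variables("model.ckpt"):
    print(name, shape)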
The code below has already been verified: it converts an ELECTRA model from the torch version to a tf ckpt.
Note that the tf 1.x and tf 2.x series still differ a great deal. TensorFlow 1.15 is used here; the 2.x series has not been tested and may run into problems.
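A minimal guard (my own addition, not part of the original script) that fails fast if a 2.x install is picked up by accident:

import tensorflow as tf

# The script below relies on TF 1.x graph-mode APIs (tf.Session, tf.get_variable,
# tf.train.Saver); on TF 2.x these only exist under tf.compat.v1, and the
# conversion has not been tested there.
if not tf.__version__.startswith("1."):
    raise RuntimeError("This conversion was only verified on TensorFlow 1.15; found " + tf.__version__)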
The command is as follows:

python electra_convert_pytorch_checkpoint_to_tf.py --model_name=electra --config_file="../Electra/electra_small/config.json" --cache_dir="../Electra/electra_small_torch" --tf_cache_dir="../Electra/electra_small_torch/tf"
The torch-to-tf conversion code is as follows:

# ELECTRA: convert a PyTorch checkpoint to TensorFlow
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
import os
import argparse
import json
import numpy as np
import tensorflow as tf
#tensorflow = 1.15
tf.enable_eager_execution()
from transformers import ElectraModel
def convert_pytorch_checkpoint_to_tf(model: ElectraModel, ckpt_dir: str, model_name: str, config: dict):
    """
    :param model: ElectraModel - PyTorch model instance to be converted
    :param ckpt_dir: TensorFlow checkpoint directory
    :param model_name: model name
    :return:
    Currently supported HF models:
        Y ElectraModel
        N ElectraForMaskedLM
        N ElectraForPreTraining
        N ElectraForSequenceClassification
        N ElectraForQuestionAnswering
    """
    # PyTorch stores Linear weights as (out_features, in_features); the original
    # TF checkpoints store kernels as (in_features, out_features), so these
    # tensors must be transposed.
    tensors_to_transpose = (
        "dense.weight",
        "attention.self.query",
        "attention.self.key",
        "attention.self.value",
        "embeddings_project.weight"
    )
    # Substitutions that map HuggingFace parameter names onto the naming
    # convention used by the original TF checkpoints.
    var_map = (
        ('layer.', 'layer_'),
        ('word_embeddings.weight', 'word_embeddings'),
        ('position_embeddings.weight', 'position_embeddings'),
        ('token_type_embeddings.weight', 'token_type_embeddings'),
        ('.', '/'),
        ('LayerNorm/weight', 'LayerNorm/gamma'),
        ('LayerNorm/bias', 'LayerNorm/beta'),
        ('weight', 'kernel')
    )

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in iter(var_map):
            name = name.replace(patt, repl)
        return 'electra/{}'.format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        session.run(tf_var)
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        # print(state_dict)
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].numpy()
            # For downstream tasks ELECTRA only needs the discriminator;
            # the generator weights can simply be skipped.
            if "generator" in tf_name:
                continue
            if any([x in var_name for x in tensors_to_transpose]):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
            print(torch_tensor.shape)
        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
def main(raw_args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name",
                        default="roberta-base",
                        type=str,
                        help="model name e.g. bert-base-uncased")
    parser.add_argument("--config_file",
                        default="./roberta_config.json",
                        type=str,
                        help="config for Tensorflow model")
    parser.add_argument("--cache_dir",
                        default="./roberta-base",
                        type=str,
                        help="Directory containing pytorch model")
    parser.add_argument("--tf_cache_dir",
                        default="./tf-roberta-base",
                        type=str,
                        help="Directory in which to save tensorflow model")
    args = parser.parse_args(raw_args)

    with open(args.config_file, 'r') as inf:
        config = json.load(inf)

    if args.cache_dir:
        model = ElectraModel.from_pretrained(
            pretrained_model_name_or_path=args.cache_dir,
            cache_dir=args.cache_dir
        )
    else:
        model = ElectraModel.from_pretrained(
            pretrained_model_name_or_path=args.model_name,
        )

    convert_pytorch_checkpoint_to_tf(
        model=model,
        ckpt_dir=args.tf_cache_dir,
        model_name=args.model_name,
        config=config,
    )


if __name__ == "__main__":
    main()
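After the conversion has run, a quick sanity check (my own addition, not part of the original script) is to list the variables in the new ckpt and confirm that the electra/... names and shapes look right:

# Inspect the converted checkpoint (illustrative; adjust the path to your output).
import tensorflow as tf  # TF 1.15

ckpt_path = "../Electra/electra_small_torch/tf/electra.ckpt"  # prefix written by the command above
for name, shape in tf.train.list_variables(ckpt_path):
    print(name, shape)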