ELECTRA Pytorch bin 转tensorflow ckpt

来源:投稿 作者:Mr.剑豪
编辑:学姐

现有的人工智能模型基本上逃不出pytorchtensorflow两个种框架,前者对于搞学习做研究来说非常友好,容易上手,后者对于开发来说更能满足需求。

在做实验的时候,常常会遇到代码和模型不匹配的问题。代码需要pytorch模型还好说,搜一搜huggingface几乎都有开源模型,但是如果需要tensorflow模型,找资源就不是那么容易了。

我之前遇到的问题是,用torch架构训了一个Electra模型,花了一周时间,但是用的时候发现,有一两个代码还是用的tensorflow,苦苦搜寻数日得不到结果,只有bertrobertatorchtf(https://github.com/percent4/roberta_torch_2_tf),其他的都没找到或者找到了也没有改成功。后来只好在roberta的转换基础上修改出了Electra的转换代码。

其实Electrarobertabert大同小异,不论是torch还是tf都是一堆模型参数,只是时候结构有些不同,这给不同架构的使用者造成了麻烦。

这个代码已经验证成功了,可以将ELECTRAtorch版本转到tfckpt版本。

需要注意的是tf_1系列和tf_2系列还是有很大区别的,这里用的是tensorflow1.15,2系列没有试过,可能会出问题。

命令如下:

python electra_convert_pytorch_checkpoint_to_tf.py --model_name=electra --config_file="../Electra/electra_small/config.json" --cache_dir="../Electra/electra_small_torch" --tf_cache_dir="../Electra/electra_small_torch/tf

torch转tf代码如下

# Electra torch转tensorflow
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
import os
import argparse
import json
import numpy as np
import tensorflow as tf
#tensorflow = 1.15

tf.enable_eager_execution()
from pytorch_transformers.modeling_roberta import RobertaModel as BertModel
from transformers import ElectraModel

def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str, config: dict):
    """
    :param model:BertModel Pytorch model instance to be converted
    :param ckpt_dir: Tensorflow model directory
    :param model_name: model name
    :return:
    Currently supported HF models:
        Y BertModel
        N BertForMaskedLM
        N BertForPreTraining
        N BertForMultipleChoice
        N BertForNextSentencePrediction
        N BertForSequenceClassification
        N BertForQuestionAnswering
    """

    tensors_to_transpose = (
        "dense.weight",
        "attention.self.query",
        "attention.self.key",
        "attention.self.value",
        "embeddings_project.weight"
    )

    var_map = (
        ('layer.', 'layer_'),
        ('word_embeddings.weight', 'word_embeddings'),
        ('position_embeddings.weight', 'position_embeddings'),
        ('token_type_embeddings.weight', 'token_type_embeddings'),
        ('.', '/'),
        ('LayerNorm/weight', 'LayerNorm/gamma'),
        ('LayerNorm/bias', 'LayerNorm/beta'),
        ('weight', 'kernel')
    )

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in iter(var_map):
            name = name.replace(patt, repl)
        return 'electra/{}'.format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        session.run(tf_var)
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        # print(state_dict)
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].numpy()

            #Electra在下游任务中只需要判别器部分,生成器部分可以直接忽略
            if "generator" in tf_name:
                continue

            if any([x in var_name for x in tensors_to_transpose]):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
            print(torch_tensor.shape)

        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))


def main(raw_args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name",
                        default="roberta-base",
                        type=str,
                        help="model name e.g. bert-base-uncased")
    parser.add_argument("--config_file",
                        default="./roberta_config.json",
                        type=str,
                        help="config for Tensorflow model")
    parser.add_argument("--cache_dir",
                        default="./roberta-base",
                        type=str,
                        help="Directory containing pytorch model")
    parser.add_argument("--tf_cache_dir",
                        default="./tf-roberta-base",
                        type=str,
                        help="Directory in which to save tensorflow model")
    args = parser.parse_args(raw_args)

    with open(args.config_file, 'r') as inf:
        config = json.load(inf)

    if args.cache_dir:
        model = ElectraModel.from_pretrained(
            pretrained_model_name_or_path=args.cache_dir,
            cache_dir=args.cache_dir
        )

    else:
        model = ElectraModel.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        )

    convert_pytorch_checkpoint_to_tf(
        model=model,
        ckpt_dir=args.tf_cache_dir,
        model_name=args.model_name,
        config=config,
    )


if __name__ == "__main__":
    main()

 点击卡片关注深度学习干货免费领

你可能感兴趣的:(粉丝的投稿,深度学习干货,深度学习,tensorflow,pytorch)