transformers下载模型到本地(tensorflow2.0)

tensorflow2.0环境,下载transformers模型到本地

要点: GPT2LMHeadModel.from_pretrained的参数需要额外加入from_tf=True

首先打开网址:https://huggingface.co/models 这个网址是huggingface/transformers支持的所有模型,目前大约一千多个。搜索gpt2(其他的模型类似,比如bert-base-uncased等),并点击进去。
通常我们需要保存的是三个文件及一些额外的文件,第一个是配置文件;config.json。第二个是词典文件,vocab.json。第三个是预训练模型文件,如果你使用pytorch则保存pytorch_model.bin文件,如果你使用tensorflow 2,则保存tf_model.h5。

额外的文件,指的是merges.txt、special_tokens_map.json、added_tokens.json、tokenizer_config.json、sentencepiece.bpe.model等,这几类是tokenizer需要使用的文件,如果出现的话,也需要保存下来。没有的话,就不必在意。

使用下载好的本地文件
使用的时候,非常简单。huggingface的transformers框架主要有三个类model类、configuration类、tokenizer类,这三个类,所有相关的类都衍生自这三个类,他们都有from_pretained()方法和save_pretrained()方法。

from_pretrained方法的第一个参数都是pretrained_model_name_or_path,这个参数设置为我们下载的文件目录即可。

果是tensorflow 2版本的,GPT2LMHeadModel.from_pretrained的参数需要额外加入from_tf=True
以bert-base-uncased为例:

from transformers import AutoTokenizer, TFAutoModel
path = "./transformers_model/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(path,from_tf=True)

你可能感兴趣的:(transformers,tensorflow)