tensorflow2调用huggingface transformer预训练模型

tensorflow2调用huggingface transformer预训练模型

  • 一点废话
    • huggingface简介
    • 传送门
    • pipline
    • 加载模型
    • 设定训练参数
    • 数据预处理
    • 训练模型
    • 结语

一点废话

好久没有更新过内容了,开工以来就是在不停地配环境,如今调通模型后,对整个流程做一个简单的总结(水一篇)。

现在的NLP行业几乎都逃不过fune-tuning预训练的bert或者transformer这一关,按照传统方法,构建整个模型,在processer里传数据,在util里配路径,在bert_fine_tune.sh里炼丹,说实话很麻烦,面对很多不需要复杂下游任务的任务,直接调用预训练模型是最便捷高效的方法,著名的huggingface正是为此而生的,但它整体面向pytorch,如何有效得在tensorflow2中使用这些模型,对新员工和新入学的盆友们是一件很头疼的事情,那么具体应该做些什么呢?往下看。

huggingface简介

Hugging face原本是一家聊天机器人初创服务商 https://huggingface.co/ 专注于NLP技术,拥有大型的开源社区。尤其是在github上开源的自然语言处理,预训练模型库 Transformers,已被下载超过一百万次,github上超过24000个star。Transformers 提供了NLP领域大量state-of-art的 预训练语言模型结构的模型和调用框架。(https://github.com/huggingface/transformers)

传送门

https://huggingface.co/models
在这里有近年来较为成熟的bert系模型,可以在目录中直接下载预训练的模型使用

pipline

pip install transformers==4.6.1
pip install tensorflow-gpu==2.4.0

需要注意的是,transforers包所需的tensorflow版本至少为2.2.0,而该版本对应的CUDA版本可能不同,如笔者使用的2.4.0版本tensorflow对应的CUDA是11版本,在此祭奠配cuda环境浪费的两天时间……

加载模型

下载好的完整模型是这样的,其中:
config 定义模型参数,如layer_num、batch等
tf_model.h5 tensorflow2模型文件
tokenizer 文本处理模块
vocab 词典
tensorflow2调用huggingface transformer预训练模型_第1张图片
记录好模型所在的目录,然后打开你的编辑器,导入所需的包,这里以序列分类为例,其他下游任务参考官方文档https://huggingface.co/transformers/model_doc/bert.html

import tensorflow as tf
import transformers as ts
from transformer import BertTokenizer, TFBertForSequenceClassification
import pandas as pd

设定训练参数

data_path = 'YOUR_DATA_PATH'
model_path = 'YOUR_MODEL_PATH'
config_path = 'YOUR_CONFIG_PATH'
num_labels = 10
epoch = 10
tokenizer = BertTokenizer.from_pretrained(model_path)

数据预处理

涉及到数据的读入和重组,注意数据格式一定要符合bert模型所需的格式

def data_incoming(path):
    x = []
    y = []
    with open(path, 'r') as f:
        for line in f.readlines():
            line = line.strip('\n')
            line = line.split('\t')
            x.append(line[0])
            y.append(line[1])
    df_row = pd.DataFrame([x, y], index=['text', 'label'])
    df_row = df_row.T
    df_label = pd.DataFrame({"label": ['YOUR_LABEL'], 'y': list(range(10))})
    output = pd.merge(df_row, df_label, on='label', how='left')
    return output

def convert_example_to_feature(review):
    return tokenizer.encode_plus(review,
                                 max_length=256,
                                 pad_tp_max_length=True,
                                 return_attention_mask=True,
                                 truncation=True
                                 )

def map_example_to_dict(input_ids, attention_mask, token_type_ids, label):
    return {
               "input_ids": input_ids,
               "token_type_ids": token_type_ids,
               "attention_mask": attention_mask,
           }, label

def encode_example(ds, limit=-1):
    input_ids_list = []
    token_type_ids_list = []
    attention_maks_list = []
    label_list = []
    if limit > 0:
        ds.take(limit)
    for index, row in ds.iterrows():
        review = row["text"]
        label = row['y']
        bert_input = convert_example_to_feature(review)
        input_ids_list.append(bert_input["input_ids"])
        token_type_ids_list.append(bert_input['token_type_ids'])
        attention_maks_list.append(bert_input['attention_maks'])
        label_list.append([label])
    return tf.data.Dataset.from_tensor_slices(
        (input_ids_list, token_type_ids_list, attention_maks_list, label_list)).map(map_example_to_dict)

具体内容就不再赘述了,已经写得很详细了,实在不懂的话……说不定我还有时间看评论

训练模型

def main():
    train = data_incoming(data_path + 'train.tsv')
    test = data_incoming(data_path + 'test.tsv')
    train = encode_example(train).shuffle(100000).batch(100)
    test = encode_example(test).batch(100)
    model = TFBertForSequenceClassification(model_path, num_labels=num_labels)
    optimizer = tf.keras.optimizers.Adam(1e-5)
    model.compile(optimizer=optimizer, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    model.fit(train, epochs=epoch, verbose=1, validation_data=test)

if __name__ == '__main__':
    main()

结语

呐,就是这么简单的一小块代码,足以让你的gpu煎鸡蛋了,然鹅笔者再摸鱼可能就要被组长裹进鸡蛋里煎了,就先到这吧,下一期……有没有还不晓得,996不配有空闲

你可能感兴趣的:(深度学习,tensorflow,深度学习,人工智能,python)