Fine-tuning BERT with the transformers Trainer

References:
https://huggingface.co/docs/transformers/training
https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_xnli.py

**Code example:**

## Load the dataset

```python
from datasets import load_dataset

# Download the IMDB movie-review dataset (binary sentiment labels) from the Hub
raw_datasets = load_dataset("imdb")
```
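A quick way to check what was loaded (an illustrative snippet, not part of the original script): `load_dataset("imdb")` returns a `DatasetDict` whose splits can be indexed like lists of dicts.

```python
print(raw_datasets)              # DatasetDict with "train", "test" (and "unsupervised") splits
print(raw_datasets["train"][0])  # one example: {"text": "...", "label": 0 or 1}
```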
## Build the tokenized training inputs

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    # Pad every example to the model's maximum length and truncate longer ones
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

# Small 1000-example subsets for a quick run; use the full splits for a real fine-tune
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
# full_train_dataset = tokenized_datasets["train"]
# full_eval_dataset = tokenized_datasets["test"]
```
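To sanity-check the preprocessing, you can tokenize one sentence by hand and inspect the fields the Trainer will feed to BERT (an illustrative snippet; the sentence is made up):

```python
sample = tokenizer("This movie was great!", padding="max_length", truncation=True)
print(list(sample.keys()))       # ['input_ids', 'token_type_ids', 'attention_mask'] for BERT
print(len(sample["input_ids"]))  # 512: padded to bert-base-cased's maximum length
```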

## Load the pretrained model and fine-tune it

```python
from transformers import AutoModelForSequenceClassification

# BERT encoder with a freshly initialized 2-class classification head on top
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

from transformers import TrainingArguments

# "test_trainer" is the output directory; every other argument keeps its default
training_args = TrainingArguments("test_trainer")
```
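`TrainingArguments("test_trainer")` only sets the output directory and leaves all hyperparameters at their defaults. In practice you usually override a few; the values below are illustrative choices for a demo run, not settings from the original post:

```python
training_args = TrainingArguments(
    output_dir="test_trainer",
    num_train_epochs=3,              # number of passes over the training set
    per_device_train_batch_size=8,   # batch size per GPU/CPU
    per_device_eval_batch_size=8,
    learning_rate=2e-5,              # a typical BERT fine-tuning learning rate
    weight_decay=0.01,
    logging_steps=50,                # log training loss every 50 optimizer steps
)
```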

```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
)
trainer.train()       # start fine-tuning
trainer.save_model()  # save model + config to training_args.output_dir

# predict() requires a tokenized dataset to run inference on
predictions = trainer.predict(small_eval_dataset)
```
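`predict()` returns raw logits, and by default the Trainer only reports loss. A minimal sketch (plain NumPy; the linked training guide uses the `datasets` metric helpers instead) showing how to turn the predictions into class labels, and how to report accuracy during evaluation via a `compute_metrics` hook:

```python
import numpy as np

# Class labels from the logits returned by predict()
pred_labels = np.argmax(predictions.predictions, axis=-1)
accuracy = (pred_labels == predictions.label_ids).mean()
print(f"accuracy on the small eval set: {accuracy:.3f}")

# The same idea wired into the Trainer, reported by trainer.evaluate()
def compute_metrics(eval_pred):
    logits, labels = eval_pred  # EvalPrediction unpacks to (predictions, label_ids)
    return {"accuracy": (np.argmax(logits, axis=-1) == labels).mean()}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)
print(trainer.evaluate())  # the returned dict includes "eval_accuracy"
```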
