from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
def compute_metrics(pred):
    """Turn a Trainer evaluation prediction into a dict of classification scores.

    Args:
        pred: Object exposing ``label_ids`` (reference labels) and
            ``predictions`` (per-class scores); the predicted class is the
            argmax over the last axis of ``predictions``.

    Returns:
        dict with the keys ``'accuracy'``, ``'f1'``, ``'precision'`` and
        ``'recall'``.
    """
    # NOTE(review): precision_recall_fscore_support and accuracy_score are not
    # imported in the visible part of this file; the names match
    # sklearn.metrics -- confirm they are brought into scope elsewhere.
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    # average='binary' scores a single positive class, so this assumes a
    # two-label task -- TODO confirm against the dataset.
    prec, rec, f_score, _support = precision_recall_fscore_support(
        y_true, y_pred, average='binary'
    )
    accuracy = accuracy_score(y_true, y_pred)

    return {
        'accuracy': accuracy,
        'f1': f_score,
        'precision': prec,
        'recall': rec,
    }
# Training configuration; evaluation_strategy is added so that evaluation
# runs during training.
training_args = TrainingArguments(
    # Output locations.
    output_dir='./classification_results',  # output directory
    logging_dir='./classification_logs',    # directory for storing logs
    # Schedule and batch sizes.
    num_train_epochs=5,                     # total number of training epochs
    per_device_train_batch_size=16,         # per-device batch size for training
    per_device_eval_batch_size=64,          # per-device batch size for evaluation
    # Optimisation.
    warmup_steps=5000,                      # warmup steps for the LR scheduler
    weight_decay=0.01,                      # strength of weight decay
    # Evaluation / logging cadence. evaluation_strategy accepts:
    #   "no"    - no evaluation is done during training
    #   "steps" - evaluation is done (and logged) every N steps
    #   "epoch" - evaluation is done at the end of each epoch
    # NOTE(review): no eval_steps is given, so the interval presumably falls
    # back to logging_steps -- confirm for the installed transformers version.
    evaluation_strategy='steps',
    logging_steps=100,
)

# Pretrained DistilBERT with a sequence-classification head.
# NOTE(review): num_labels is not passed, so the library default applies --
# verify it matches the label set used by compute_metrics.
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased"
)

# compute_metrics is passed in so that evaluation reports the scores it returns.
# train_dataset and val_dataset are expected to be defined elsewhere.
trainer = Trainer(
    model=model,                      # the instantiated model to be trained
    args=training_args,               # training arguments, defined above
    train_dataset=train_dataset,      # training dataset
    eval_dataset=val_dataset,         # evaluation dataset
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.evaluate()