Huggingface走到4.8.2这个版本,已经有了很好的封装。训练一个语言网络只需要调用Trainer.train(...)即可完成。如果要根据自己的需求修改训练的过程,比如自定义loss,输出梯度,直接修改huggingface的源码显然是不可取的了。好在huggingface提供了相应的接口,让我们可以深入到训练过程中,加入自定义的内容。根据官方的教程,有两种推荐的方法:
关于trainer和callbacks这两个的官方文档分别是这里和这里,这两个方法都可以很优雅地修改原有的逻辑。但个人感觉重载trainer的方法是一种更灵活也更强大的方法。callbacks其实只能查看提供的一些变量,并且也只是查看,不能做出修改。而重载方法可以定义任意的全新的函数。接下来给出这两种方法的两个例子。
在官方给的教程中是一个重载compute loss的例子,这里给一个不一样的,定义trainging_step的例子,代码如下:
class PrintGradientTrainer(Trainer):
def training_step(self, model, inputs):
model.train()
inputs = self._prepare_inputs(inputs)
loss = self.compute_loss(model, inputs)
loss.backward()
# ------------------------new added codes.--------------------------
for name, param in model.named_parameters():
if param.requires_grad:
if param.grad is not None:
print("{}, gradient: {}".format(name, param.grad.mean()))
else:
print("{} has not gradient".format(name))
# ------------------------new added codes.--------------------------
return loss.detach()
# originally the Trainer() is called
#trainer = Trainer(
# model=model, args=training_args, train_dataset=small_train_dataset, #eval_dataset=small_eval_dataset,
# tokenizer=tokenizer, data_collator=data_collator
#)
# Now call the new defined PrintGradientTrainer()
trainer = PrintGradientTrainer(
model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset,
tokenizer=tokenizer, data_collator=data_collator
)
trainer.train()
只给出了关键部分的代码,其他的就按照正常写即可。
这个方法也需要定义一个原本的TrainerCallback的子类,然后重载原有的空的callbacks方法。代码实例如下,这个例子打出了现在是第几个epoch。
class MyCallback(TrainerCallback):
def on_step_begin(self, args, state, control, **kwargs):
print("train step start")
control.should_log = False
control.should_evaluate = False
control.should_save = False
print('---------------------------------------',state.epoch)
# return self.call_event("on_step_begin", args, state, control)
trainer = PrintGradientTrainer(
model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset,
tokenizer=tokenizer, data_collator=data_collator,callbacks=[MyCallback()]
)
在定义trainer的时候,给callbacks加入自己定义的类就可以了。