[transformers] Using torch.profiler.profile with the Trainer

Today I needed to profile a model trained through the transformers Trainer API. The Trainer is highly encapsulated and awkward to modify directly, but the profiling can be driven cleanly through a callback instead.

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class MyCallback(TrainerCallback):
    "A callback that prints a message at the start of training and advances the profiler after every training step"
    def __init__(self, prof):
        self.prof = prof

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Signal the profiler that one training step has finished so its
        # schedule (wait/warmup/active) can advance.
        self.prof.step()
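The profiling block further down assumes a Trainer instance named trainer already exists. For completeness, here is a minimal sketch of such a setup; the model name, synthetic dataset, and training arguments are placeholder assumptions, not part of the original post:

import torch
from torch.utils.data import Dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

class DummyDataset(Dataset):
    "A tiny synthetic classification dataset, just enough to drive a few training steps"
    def __init__(self, tokenizer, size=64):
        self.encodings = tokenizer(["hello world"] * size, truncation=True,
                                   padding="max_length", max_length=16)
        self.labels = [0] * size

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
training_args = TrainingArguments(output_dir="hf-prof-out", max_steps=20,
                                  per_device_train_batch_size=8, report_to="none")
trainer = Trainer(model=model, args=training_args, train_dataset=DummyDataset(tokenizer))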

Then wrap the training run in a torch.profiler.profile context and attach the callback to the trainer:

import torch

with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
        # Skip the first 3 steps, then cycle twice through: 1 wait step,
        # 1 warmup step, and 2 actively recorded steps.
        schedule=torch.profiler.schedule(skip_first=3, wait=1, warmup=1, active=2, repeat=2),
        # Write a TensorBoard trace into ./hf-training-trainer whenever an active window completes.
        on_trace_ready=torch.profiler.tensorboard_trace_handler('hf-training-trainer'),
        profile_memory=True,
        with_stack=True,
        record_shapes=True) as prof:

    trainer.add_callback(MyCallback(prof=prof))
    trainer.train()
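Once the with block exits, the trace files in hf-training-trainer can be opened with TensorBoard's PyTorch Profiler plugin (pip install torch_tb_profiler, then tensorboard --logdir hf-training-trainer). As a quick sanity check without TensorBoard, the profiler object can also print a summary table; a minimal sketch, where the sort key and row limit are arbitrary choices:

# Summarise the operators recorded in the last completed active window.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))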

References

hf_training_trainer_prof.py

Is there a pytorch profiler integration with huggingface trainer?
