PyTorch Lightning教程六:优化代码

有时候模型训练很慢,代码写得冗长之后,没法诶个检查到底那块出现了占用了时空间,本节通过利用Lightning的一些方法,检查分析是那块代码出现了问题,从而来进一步指导和优化代码

本节主要基于性能分析方法,通过捕获分析信息(例如函数花费的时间或使用了多少内存)帮助我们找到代码中的瓶颈。

找到训练时候的瓶颈

最基本的性能分析配置文件,包含训练中Callback、DataModules和LightningModule中的所有关键方法。可以通过如下方法引入

trainer = Trainer(profiler="simple")

一旦执行.fit()方法,则可以看到如下类似结果

FIT Profiler Report
-----------------------------------------------------------------------------------------------
|  Action                                          |  Mean duration (s)     |  Total time (s) |
-----------------------------------------------------------------------------------------------
|  [LightningModule]BoringModel.prepare_data       |  10.0001               |  20.00          |
|  run_training_epoch                              |  6.1558                |  6.1558         |
|  run_training_batch                              |  0.0022506             |  0.015754       |
|  [LightningModule]BoringModel.optimizer_step     |  0.0017477             |  0.012234       |
|  [LightningModule]BoringModel.val_dataloader     |  0.00024388            |  0.00024388     |
|  on_train_batch_start                            |  0.00014637            |  0.0010246      |
|  [LightningModule]BoringModel.teardown           |  2.15e-06              |  2.15e-06       |
|  [LightningModule]BoringModel.on_train_start     |  1.644e-06             |  1.644e-06      |
|  [LightningModule]BoringModel.on_train_end       |  1.516e-06             |  1.516e-06      |
|  [LightningModule]BoringModel.on_fit_end         |  1.426e-06             |  1.426e-06      |
|  [LightningModule]BoringModel.setup              |  1.403e-06             |  1.403e-06      |
|  [LightningModule]BoringModel.on_fit_start       |  1.226e-06             |  1.226e-06      |
-----------------------------------------------------------------------------------------------

在这个打印出来的报告中,我们可以看到最慢的函数是prepare_data,现在我们可以弄清楚为什么数据准备会减慢训练速度。执行profiler="simple",会包括:

  • on_train_epoch_start
  • on_train_epoch_end
  • on_train_batch_start
  • model_backward
  • on_after_backward
  • optimizer_step
  • on_train_batch_end
  • on_training_end
  • 等等……

分析每个函数内的时间

要分析每个函数花费的时间,使用构建在Python的cProfiler之上的AdvancedProfiler,如下引用:

trainer = Trainer(profiler="advanced")

执行fit后,会出现如下结果

Profiler Report

Profile stats for: get_train_batch
        4869394 function calls (4863767 primitive calls) in 18.893 seconds
Ordered by: cumulative time
List reduced from 76 to 10 due to restriction <10>
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
3752/1876    0.011    0.000   18.887    0.010 {built-in method builtins.next}
    1876     0.008    0.000   18.877    0.010 dataloader.py:344(__next__)
    1876     0.074    0.000   18.869    0.010 dataloader.py:383(_next_data)
    1875     0.012    0.000   18.721    0.010 fetch.py:42(fetch)
    1875     0.084    0.000   18.290    0.010 fetch.py:44()
    60000    1.759    0.000   18.206    0.000 mnist.py:80(__getitem__)
    60000    0.267    0.000   13.022    0.000 transforms.py:68(__call__)
    60000    0.182    0.000    7.020    0.000 transforms.py:93(__call__)
    60000    1.651    0.000    6.839    0.000 functional.py:42(to_tensor)
    60000    0.260    0.000    5.734    0.000 transforms.py:167(__call__)

如果分析器报告变得太长,可以将报告流式传输到一个文件:

from lightning.pytorch.profilers import AdvancedProfiler

profiler = AdvancedProfiler(dirpath=".", filename="perf_logs")
trainer = Trainer(profiler=profiler)

很方便!

分析加速器使用情况

另一种检测瓶颈的有用技术,是确保正在使用加速器(GPU/TPU/IPU/HPU)的全部容量。这可以用DeviceStatsMonitor来测量:

from lightning.pytorch.callbacks import DeviceStatsMonitor

trainer = Trainer(callbacks=[DeviceStatsMonitor()])

CPU指标将在CPU加速器上默认跟踪。设置DeviceStatsMonitor(cpu_stats=True)为其他加速器启用它。要禁用记录CPU指标,可以指定DeviceStatsMonitor(cpu_stats=False)

你可能感兴趣的:(pytorch,人工智能,python)