【PyTorch Lightning】1.0 正式发布:从 0 到 1

目录

一、Lightning DNA

二、1.0.0 的新功能

三、研究 + 生产

四、网站

五、度量 (Metrics)

六、手动优化与自动优化

七、日志 (Logging)

八、数据流 (data flow)

九、检查点 (Checkpointing)


PyTorch Lightning 是基于 PyTorch 的高级框架,简洁易用,在云上大规模部署很有优势。

  • 作者:PyTorch Lightning team
  • 编译:McGL

还记得那个看起来像 Keras 的轻量版 PyTorch 框架 Lightning 吗?它终于出了 1.0.0 版本,并增添了很多新功能,在度量、优化、日志记录、数据流、检查点等方面均进行了完善。

Keras 和 PyTorch 都是对初学者非常友好的深度学习框架,两者各有优势,很多研究者和开发者在选择框架时可能会举棋不定。基于这种情况,grid.ai CEO、纽约大学博士 William Falcon 创建了 PyTorch Lightning,为 PyTorch 披上了一件 Keras 的外衣。

Lightning 是 PyTorch 非常轻量级的包装,研究者只需要编写最核心的训练和验证逻辑,其它过程都会自动完成。因此这就有点类似 Keras 那种高级包装,它隐藏了绝大多数细节,只保留了最通俗易懂的接口。Lightning 能确保自动完成部分的正确性,对于核心训练逻辑的提炼非常有优势。

近期,PyTorch Lightning 在推特宣布,1.0.0 版本现在可用了,并发布新的博客文章详细描述了 PyTorch Lightning 的运行原理和新的 API。William Falcon 表示自己非常期待有一天,当用户查看 GitHub 上的复杂项目时,深度学习代码不再那么令人望而生畏。

特斯拉 AI 负责人 Andrej Karpathy 也评论称:「这看起来很棒,也很有前途。PyTorch Lightning 倡导对深度学习代码进行重构,将『工程(硬件)』与『科学(代码)』分割开,然后将前者委托给框架。」

【PyTorch Lightning】1.0 正式发布:从 0 到 1_第1张图片

在过去的几个月里,我们一直在努力工作,微调 API,改进文档,录制教程,现在终于是时候与大家分享 PyTorch Lightning 的 V1.0.0版了。想要云上缩放模型的极速方案吗?请继续阅读。


一、Lightning DNA

AI 研究的发展速度远远超过任何单一框架所能跟上的速度。深度学习的领域不断发展,主要是在复杂性和规模。Lightning 提供了一个为复杂模型交互世界设计的用户体验,同时抽象出所有令人分心的工程细节,如多 GPU 和多 TPU 训练,early stopping,日志等......

像 PyTorch 这样的框架是为 AI 研究主要关注网络架构的时代而设计的。nn.Module 模块可以定义操作顺序。

以下是 VGG16 的代码结构:

【PyTorch Lightning】1.0 正式发布:从 0 到 1_第2张图片

这些框架在为研究或生产提供极其复杂的模型所需的所有部件方面做出了令人难以置信的工作。但是一旦模型开始相互作用,比如 GAN,BERT,或者自动编码器,这种模式就会打破,巨大的灵活性很快就会变成样板,项目上了规模就很难维护。

与之前出现的框架不同,PyTorch Lightning 被设计成封装一系列相互作用的模型,我们称之为 深度学习系统(deep learning systems)。Lightning 是为当今世界更复杂的研究和生产案例而设计的,在这些案例中,许多模型使用复杂的规则相互作用。

以下展示了一个自编码器系统示意图:

【PyTorch Lightning】1.0 正式发布:从 0 到 1_第3张图片

PyTorch Lightning 的第二个关键原则是硬件和“科学”代码必须分开。Lightning 进化到可以利用大规模的计算,而不需要向用户展示任何抽象概念。通过这种分离,你获得了以前不可能的新能力,比如在笔记本电脑上使用 CPU 调试你的512 GPU 作业而不需要更改代码。

最后,Lightning 创建的愿景是成为一个社区驱动的框架。

构建优秀的深度学习模型需要大量的专业知识和使系统工作的小技巧。在世界各地,数以百计令人难以置信的工程师和博士们一遍又一遍地实现相同的代码。Lightning 现在有一个不断增长的贡献者社区,其中有超过300个极有才华的深度学习人员,他们选择分配相同的能量,做完全相同的优化,但是却有成千上万的人从他们的努力中受益。

【PyTorch Lightning】1.0 正式发布:从 0 到 1_第4张图片


二、1.0.0 的新功能

Lightning 1.0.0 标志着一个 稳定的 最终版  API

这意味着依赖于 Lightning 的主要研究项目可以放心使用,他们的代码在未来不会中断或改变。


三、研究 + 生产

Lightning 的核心优势是使最先进的 AI 研究可以大规模扩展。这是一个为专业研究人员设计的框架,在最大的计算资源上尝试最难的想法,而不会失去任何灵活性。

我们很兴奋地宣布,Lightning 1.0.0 现在 还可以 轻松地大规模部署这些模型。所有的 Lightning 代码确保了所有的东西都可以轻松导出到 onnx 和 torchscript。

# ----------------------------------
# torchscript
# ----------------------------------
autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
os.path.isfile("model.pt")

# ----------------------------------
# onnx
# ----------------------------------
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
     autoencoder = LitAutoEncoder()
     input_sample = torch.randn((1, 28 * 28))
     autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
     os.path.isfile(tmpfile.name)

因此,这意味着你的数据科学家、研究人员等团队现在还可以成为将模型投入生产的人。他们不需要庞大的机器学习工程师团队。

这是 领先的公司使用 Lightning 的一个主要原因: 作为一种帮助他们大大缩短生产时间而不失去任何研究所需的灵活性的方法。

这正是我们企业级服务提供的: Grid AI 是我们在云上进行规模训练的原生平台。Grid 允许任何构建深度学习模型的人在大规模计算资源上迭代,然后立即将这些模型部署到一个可伸缩的环境中,能够处理你扔给深度学习系统的最大流量。以下是 Grid 训练 简图:

【PyTorch Lightning】1.0 正式发布:从 0 到 1_第5张图片


四、网站

你还会注意到,我们已经整合了所有的博客文章,极速的视频教程,社区项目和其他资源在我们的全新主页下,展示所有的东西快如闪电!

【PyTorch Lightning】1.0 正式发布:从 0 到 1_第6张图片


五、度量 (Metrics)

pytorch_lightning.metrics 是一个为了在 PyTorch 和 PyTorch Lightning 中方便度量开发和使用而创建的度量 API。更新的 API 提供了一种内置方法,可以跨多个 GPU (进程)计算每步的度量,同时存储统计信息,允许你在一个 epoch 结束时计算度量,而不必担心与分布式后端相关的任何复杂性。

它对所有的边缘情况都进行了严格的测试,并且包含了越来越多的常用度量实现,比如 Accuracy、 Precision、 Recall、 Fbeta、 MeanSquaredError 等等。

class LitModel(pl.LightningModule):
    def __init__(self):
        ...
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.train_acc(logits, y)
        # log step metric
        self.log('train_acc_step', self.train_acc)

    def validation_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        # logs epoch metrics
        self.log('valid_acc', self.valid_acc)

要实现自定义度量,只需子类化基本 Metric 类并实现 __init__()、 update() 和 compute() 方法。你所需要做的就是正确调用 add _ state (),以便使用 DDP 实现自定义度量。使用 add_state() 添加的度量状态变量调用 reset()。

from pytorch_lightning.metrics import Metric

class MyAccuracy(Metric):

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape
        self.correct += torch.sum(preds == target)
        self.total += target.numel()
 
    def compute(self):
        return self.correct.float() / self.total

六、手动优化与自动优化

使用 Lightning,你不需要担心什么时候启用/禁用梯度,做一个后向传播,或者更新优化器,只要你从 training_step 返回一个附加图(graph),Lightning 将自动优化。

def training_step(self, batch, batch_idx):
    loss = self.encoder(batch[0])
    return loss

然而,对于某些研究,比如 GAN、强化学习或者有多个优化器或者内部循环的东西,你可以关闭自动优化,自己完全控制训练循环。

首先,关闭自动优化:

trainer = Trainer(automatic_optimization=False)

现在你控制了训练循环!

def training_step(self, batch, batch_idx, opt_idx):
    (opt_a, opt_b, opt_c) = self.optimizers()
    loss_a = self.generator(batch[0])
    # use this instead of loss.backward so we can automate half
    # precision, etc...
    self.manual_backward(loss_a, opt_a, retain_graph=True)
    self.manual_backward(loss_a, opt_a)
    opt_a.step()
    opt_a.zero_grad()
    loss_b = self.discriminator(batch[0])
    self.manual_backward(loss_b, opt_b)
    ...

七、日志 (Logging)

Lightning 使得 loggers 的集成变得非常简单——只需在 LightningModule 的任何地方调用 log() 方法,它就会将记录的数量发送到你选择的 logger。默认情况下我们使用 Tensorboard,但是你可以选择任何你想用的支持的 logger。

def training_step(self, batch, batch_idx):
  self.log('my_metric', x)

根据 .log () 的调用位置,Lightning 自动确定何时应该进行日志记录(每步或每个epoch) ,但是当然你可以通过手动使用 on_step 和 on_epoch 选项来覆盖默认行为。设置为 on_epoch = True 将在整个训练 epoch 期间累积你的日志值。

def training_step(self, batch, batch_idx):
  self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

八、数据流 (data flow)

我们 deprecate 了 EvalResult 和 TrainResult,这有利于简化数据流,并在训练和验证循环中将日志与数据解耦。

每个循环(训练、验证、测试)都有三个可以实现的钩子(hooks):

  • x_step
  • x_step_end
  • x_epoch_end

为说明数据是如何流动的,我们将使用训练循环(即: x = training)

outs = []
for batch in data:
  out = training_step(batch)
  outs.append(out)
training_epoch_end(outs)

你在 training_step 中返回的任何东西都可以作为 training_epoch_end 的输入。

def training_step(self, batch, batch_idx):
  prediction = …
  return {'loss': loss, 'preds': prediction}

def training_epoch_end(self, training_step_outputs):
  for out in training_step_outputs:
    prediction = out['preds']
  # do something with these

验证和测试步骤也是如此: validation_step 或 test_step 中返回的任何内容都可以用作 { validation/test }_step_end 或 { validation/test }_epoch_end 的输入。如果你使用 DP 或 DDP2分布式模式(即: 拆分 batch 到不同的 GPU) ,请使用 x_step_end 手动聚合(或者不实现它,让 lightning 自动聚合)。


九、检查点 (Checkpointing)

Lightning 现在自动为你保存一个 checkpoint 在你的当前工作目录,还有你的最后一个训练 epoch 的状态。这样可以确保在训练被中断的情况下继续进行训练。

你可以自定义 checkpointing 行为来监控任意数量的训练或验证步骤。例如,如果你想根据验证损失更新 checkpoint:

  1. 计算你希望监控的任何指标或其他数量,例如验证集损失。
  2. 使用 log() 方法记录值,并用一个键如 val_loss。
  3. 初始化 ModelCheckpoint 回调,并设置监视器为你所记录值的键。
  4. 回调传递给 checkpoint_callback Trainer flag。
from pytorch_lightning.callbacks import ModelCheckpoint

class LitAutoEncoder(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)

        # 1. calculate loss
        loss = F.cross_entropy(y_hat, y)

        # 2. log `val_loss`
        self.log('val_loss', loss)

# 3. Init ModelCheckpoint callback, monitoring 'val_loss'
checkpoint_callback = ModelCheckpoint(monitor='val_loss')

# 4. Pass your callback to checkpoint_callback trainer flag
trainer = Trainer(checkpoint_callback=checkpoint_callback)

请在我们的 release notes (https://github.com/PyTorchLightning/pytorch-lightning/releases)中阅读所有的 API 变化,其中包括很多 bug 的修复。


参考文献

https://medium.com/pytorch/pytorch-lightning-1-0-from-0-600k-80fc65e2fab0

https://www.jiqizhixin.com/articles/2020-10-22-10

你可能感兴趣的:(【PyTorch,Lightning】,1024程序员节)