这篇文章主要介绍为什么使用pytorch时,需要使用Lightning的最常见问题。由Pytorch Lightning的主创团队编写(William Falcon),经本文翻译。
PyTorch非常易于使用,可以构建复杂的AI模型。但是一旦研究变得复杂,并且将诸如多GPU训练,16位精度和TPU训练之类的东西混在一起,用户很可能引入Bug。
PyTorch Lightning完全解决了这个问题。Lightning会构建您的PyTorch代码,以便可以抽象出训练的细节。这使得AI研究可扩展且可快速迭代。
该系列上一篇地址: pytorch-lightning入门(一)—— 初了解
PyTorch Lightning是在NYU和FAIR进行博士研究时创建的
PyTorch Lightning是为从事AI研究的专业研究人员
和博士生
而创建的。
Lightning来自我的博士学位。人工智能研究的纽约大学CILVR和Facebook的AI研究。结果,该框架被设计为具有极强的可扩展性,同时又使最先进的AI研究技术(例如TPU训练)变得微不足道。
现在,核心贡献者都在使用Lightning推动AI的发展,并继续添加新的炫酷功能。
但是,简单的界面使专业的生产团队和新手可以使用Pytorch和PyTorch Lightning社区开发的最新技术。
Lightning拥有超过320名贡献者,由11名研究科学家,博士研究生和专业深度学习工程师组成的核心团队。
本教程将引导您构建一个简单的MNIST分类器
,并排显示PyTorch和PyTorch Lightning代码。虽然Lightning可以构建任何任意复杂的系统,使用MNIST来说明如何将PyTorch代码重构为PyTorch Lightning。
完整的代码可在此Colab Notebook中获得。
典型的AI研究项目
在研究项目中,我们通常希望确定以下关键组成部分:
设计一个三层全连接神经网络,该网络以28x28的图像作为输入,并输出10个可能标签上的概率分布。
首先,在PyTorch中定义模型
该模型定义了计算图,以将MNIST图像作为输入,并将其转换为10个类别(0-9数字)的概率分布。
3-layer network (illustration by: William Falcon)
要将模型转换为PyTorch Lightning,只需将pl.LightningModule
替换掉nn.Module
Lightning 提供了结构化的 PyTorch code
看!两者的代码完全相同!
这意味着可以像使用PyTorch模块一样完全使用LightningModule,例如预测
或者用于预训练
在本教程中,使用MNIST。
同样,PyTorch中的代码与Lightning中的代码相同。
数据集被添加到数据加载器Dataloader
中,该数据加载器处理数据集的加载
,shuffling
, batching
。
简而言之,数据准备包括四个步骤:
同样,除了将PyTorch代码组织为4个函数之外,代码完全相同:
对于此代码中的一些关键函数解释如下:
此功能处理下载和任何数据处理。此功能可确保当您使用多个GPU时,不会下载多个数据集或对数据进行双重操作。
这是因为每个GPU将执行相同的PyTorch,从而导致重复。所有在Lightning的代码可以确保关键部件是从所谓的仅一个GPU。
** train_dataloader,val_dataloadertest_dataloader**
每一个都负责返回适当的数据拆分。Lightning以这种方式进行构造,因此非常清楚如何操作数据。如果曾经阅读用PyTorch编写的随机github代码,则几乎看不到如何操纵数据。
Lightning甚至允许多个数据加载器进行测试或验证。
这段代码是根据我们所谓的DataModule进行组织的。尽管这是100%可选的,并且闪电可以直接使用DataLoaders,但DataModule可以使您的数据可重用并且易于共享。
现在选择如何进行优化。将使用Adam
而不是SGD,因为它在大多数DL研究中都是很好的默认设置。
同样,这两者完全相同,只是它被组织到configure optimizers
功能中。
Lightning非常容易扩展。例如,如果想使用多个优化器(即GAN),则可以在此处返回两者。
还会注意到,在Lightning中,传入了self.parameters() 而不是model,因为LightningModule就是model。
对于n向分类,要计算交叉熵损失。交叉熵与将使用的NegativeLogLikelihood(log_softmax)相同。
再次……代码是完全一样的!
Training and Validation Loop
我们汇总了训练所需的所有关键要素:
现在,执行一个完整的训练例程,该例程执行以下操作:
优化
在PyTorch和Lightning中,伪代码都看起来像这样
但这就是 Lightning不同的地方。在PyTorch中,自己编写了for循环,这意味着必须记住要以正确的顺序调用正确的东西-这为错误留下了很多空间。
即使模型很简单,也不会像开始做更高级的事情那样,例如使用多个GPU,梯度裁剪,提早停止,检查点,TPU训练,16位精度等……代码复杂性将迅速爆炸。
即使模型很简单,也不会一开始就做更高级的事情
这是PyTorch和Lightning的验证和训练循环
这就是Lightning的美。它是抽象模板,内容保持不变, 只做了结构上的调整。这意味着您仍在编写PyTorch,除了您的代码结构变得良好。
这增加了可读性,有助于再现性!
The trainer
is how we abstract the boilerplate code.
同样,这是可能的,因为要做的就是将PyTorch代码组织到LightningModule中
用PyTorch编写的完整MNIST示例如下:
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
# -----------------
# MODEL
# -----------------
class LightningMNISTClassifier(pl.LightningModule):
def __init__(self):
super(LightningMNISTClassifier, self).__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 256)
self.layer_3 = torch.nn.Linear(256, 10)
def forward(self, x):
batch_size, channels, width, height = x.sizes()
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
# layer 1
x = self.layer_1(x)
x = torch.relu(x)
# layer 2
x = self.layer_2(x)
x = torch.relu(x)
# layer 3
x = self.layer_3(x)
# probability distribution over labels
x = torch.log_softmax(x, dim=1)
return x
# ----------------
# DATA
# ----------------
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
mnist_test = MNIST(os.getcwd(), train=False, download=True)
# The dataloaders handle shuffling, batching, etc...
mnist_train = DataLoader(mnist_train, batch_size=64)
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_test = DataLoader(mnist_test, batch_size=64)
# ----------------
# OPTIMIZER
# ----------------
pytorch_model = MNISTClassifier()
optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=1e-3)
# ----------------
# LOSS
# ----------------
def cross_entropy_loss(logits, labels):
return F.nll_loss(logits, labels)
# ----------------
# TRAINING LOOP
# ----------------
num_epochs = 1
for epoch in range(num_epochs):
# TRAINING LOOP
for train_batch in mnist_train:
x, y = train_batch
logits = pytorch_model(x)
loss = cross_entropy_loss(logits, y)
print('train loss: ', loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()
# VALIDATION LOOP
with torch.no_grad():
val_loss = []
for val_batch in mnist_val:
x, y = val_batch
logits = pytorch_model(x)
val_loss.append(cross_entropy_loss(logits, y).item())
val_loss = torch.mean(torch.tensor(val_loss))
print('val_loss: ', val_loss.item())
Lightning中的完整训练循环
Lightning版本完全相同,除了:
Trainer
抽象化import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
class LightningMNISTClassifier(pl.LightningModule):
def __init__(self):
super().__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 256)
self.layer_3 = torch.nn.Linear(256, 10)
def forward(self, x):
batch_size, channels, width, height = x.size()
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
# layer 1 (b, 1*28*28) -> (b, 128)
x = self.layer_1(x)
x = torch.relu(x)
# layer 2 (b, 128) -> (b, 256)
x = self.layer_2(x)
x = torch.relu(x)
# layer 3 (b, 256) -> (b, 10)
x = self.layer_3(x)
# probability distribution over labels
x = torch.log_softmax(x, dim=1)
return x
def cross_entropy_loss(self, logits, labels):
return F.nll_loss(logits, labels)
def training_step(self, train_batch, batch_idx):
x, y = train_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
self.log('train_loss', loss)
return loss
def validation_step(self, val_batch, batch_idx):
x, y = val_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
self.log('val_loss', loss)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
# data
# transforms for images
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# prepare transforms standard to MNIST
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
train_dataloader = DataLoader(mnist_train, batch_size=64)
val_loader = DataLoader(mnist_test, batch_size=64)
# train
model = LightningMNISTClassifier()
trainer = pl.Trainer()
trainer.fit(model, train_dataloader, val_loader)
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
class LightningMNISTClassifier(pl.LightningModule):
def __init__(self):
super().__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 256)
self.layer_3 = torch.nn.Linear(256, 10)
def forward(self, x):
batch_size, channels, width, height = x.size()
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
# layer 1 (b, 1*28*28) -> (b, 128)
x = self.layer_1(x)
x = torch.relu(x)
# layer 2 (b, 128) -> (b, 256)
x = self.layer_2(x)
x = torch.relu(x)
# layer 3 (b, 256) -> (b, 10)
x = self.layer_3(x)
# probability distribution over labels
x = torch.log_softmax(x, dim=1)
return x
def cross_entropy_loss(self, logits, labels):
return F.nll_loss(logits, labels)
def training_step(self, train_batch, batch_idx):
x, y = train_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
self.log('train_loss', loss)
return loss
def validation_step(self, val_batch, batch_idx):
x, y = val_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
self.log('val_loss', loss)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
class MNISTDataModule(pl.LightningDataModule):
def setup(self, stage):
# transforms for images
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# prepare transforms standard to MNIST
self.mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
self.mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=64)
def val_dataloader(self):
return DataLoader(self.mnist_test, batch_size=64)
data_module = MNISTDataModule()
# train
model = LightningMNISTClassifier()
trainer = pl.Trainer()
trainer.fit(model, data_module)
结构化
的def training_step(self, batch, batch_idx):
x, y = batch
# define your own forward and loss calculation
hidden_states = self.encoder(x)
# even as complex as a seq-2-seq + attn model
# (this is just a toy, non-working example to illustrate)
start_token = ''
last_hidden = torch.zeros(...)
loss = 0
for step in range(max_seq_len):
attn_context = self.attention_nn(hidden_states, start_token)
pred = self.decoder(start_token, attn_context, last_hidden)
last_hidden = pred
pred = self.predict_nn(pred)
loss += self.loss(last_hidden, y[step])
#toy example as well
loss = loss / max_seq_len
return {'loss': loss}
还有tensorboard日志
还有 checkpointing, and early stopping
但是Lightning以开箱即用的东西(例如TPU训练等)而闻名。
在Lightning中,可以在CPU,GPU,多个GPU或TPU上训练模型,而无需更改PyTorch代码的一行。
Trainer(precision=16)
使用Tensorboard的其他5种替代方法进行记录点击查看
使用Neptune.AI进行日志记录(鸣谢:Neptune.ai)
使用Comet.ml记录
甚至有一个内置的探查器profiler,可以告诉训练中瓶颈的位置。
trainer = Trainer(..., profiler=True)
将此标志设置为True, 将提供如下输出
或更高级的输出(如果需要)
还可以一次在多个GPU上进行训练而无需做任何工作(仍然必须提交SLURM作业)
它支持大约40种其他功能,可以在文档中阅读这些功能。
可能想知道Lightning如何为做到这一点,又以某种方式做到这一点,以便完全掌控一切?
与keras或其他高级框架不同,Lightning不会隐藏任何必要的细节。但是,如果确实需要自己修改训练的各个方面,那么有两个主要选择。
首先是通过覆盖钩子( hooks)的可扩展性。这是一个非详尽的清单:
回调是希望在训练的各个部分执行的一段代码。在Lightning中,回调保留用于非必需的代码,例如日志记录或与研究代码无关的东西。这使研究代码保持超级干净和有条理。
假设想在训练的各个部分打印或保存一些内容。这是回调的样子
PyTorch Lightning回调
这种范例将研究代码组织在三个不同的存储库中