PyTorch Lightning基础入门

Lightning in 15 minutes

Lightning in 15 minutes — PyTorch Lightning 2.0.4 documentation

安装 PyTorch Lightning

pip install lightning

conda install lightning -c conda-forge

定义一个LightningModule

LightningModule可以让pytorchnn.Module可以整合一些训练过程(也可以有验证和测试)。

如下是一个手写数字识别自动编码器(autoencoder)的样例:

import os
import torch
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl

'''
定义两个模型,编码器和解码器,这个是pytorch的模型对象
'''
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# 定义LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # 训练步骤
        # 这个跟 forward 不相关
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # 存储日志(需要安装Tensorboard)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
				# 优化器
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# 初始化自动编码器
autoencoder = LitAutoEncoder(encoder, decoder)

定义数据集

Lightning支持所有可迭代的数据集形式(DataLoadernumpy,以及其他)。

# setup 
datadataset=MNIST(os.getcwd(),download=True,transform=ToTensor())
train_loader=utils.data.DataLoader(dataset)

训练模型

LightningTrainer对象可以整合LightningModule与不同数据集,并扩展了一些工程所需方法。

# 训练模型
trainer = pl.Trainer(limit_train_batches=100, max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

Trainer对象也实现了很多常用的过程:

  1. Epochbatch迭代。
  2. optimizer.step()loss.backward()optimizer.zero_grad()
  3. 验证过程中的**model.eval()。**
  4. 模型存储和载入
  5. Tensorboard
  6. 多GPU
  7. TPU
  8. 半精度混合

【注意】:在jupyter下,多卡训练可能会报错,可以试试直接用python代码。

使用模型

训练完模型后,可以导出到 onnx、torchscript 并将其投入生产,或者只是加载权重并运行预测。

# 载入模型
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# 选择训练好的编码器
encoder = autoencoder.encoder
encoder.eval()

# 编码图片
fake_image_batch = torch.randn(8, 28 * 28).to(next(encoder.parameters()).device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

训练可视化

如果安装了Tensorboard,可以用它来观察实验过程。

tensorboard --logdir .

额外训练设置

# 4gpu训练
trainer = Trainer(
    devices=4,
    accelerator="gpu",
 )

# train 1TB+ parameter models with Deepspeed/fsdp
# 使用 Deepspeed 训练大模型
trainer = Trainer(
    devices=4,
    accelerator="gpu",
    strategy="deepspeed_stage_2",
    precision=16
 )

# 20+ helpful flags for rapid idea iteration
# 有助于快速迭代的一些设置
trainer = Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1
 )

# access the latest state of the art techniques
# 获取最新的技术
trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])

一些灵活设置

定制训练循环

PyTorch Lightning基础入门_第1张图片

LightningModule中设置了20多种断点(HOOK),可以用来定制训练过程:

class LitAutoEncoder(pl.LightningModule):
    def backward(self, loss):
        loss.backward()

扩展Trainer

PyTorch Lightning基础入门_第2张图片

在上面这个代码种,对模型的存储进行了一些设置。这些设置可以在pl.Callback对象中实现,并导入Trainer对象。

PyTorch Lightning基础入门_第3张图片

用如下方式可以导入Trainer:

trainer = Trainer(callbacks=[AWSCheckpoints()])

你可能感兴趣的:(pytorch学习,pytorch,人工智能,python)