- https://pytorch-lightning.readthedocs.io/en/latest/
- https://github.com/PyTorchLightning/pytorch-lightning
- https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html
- https://zhuanlan.zhihu.com/p/353985363
- https://zhuanlan.zhihu.com/p/319810661
- https://github.com/miracleyoo/pytorch-lightning-template
Install
pip install pytorch-lightning -i https://pypi.doubanio.com/simple
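A quick way to verify the install (the -i flag above only points pip at the Douban mirror and can be dropped):
import pytorch_lightning as pl
print(pl.__version__)  # whatever version pip resolved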
AutoEncoder
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
class LitAutoEncoder(pl.LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
def forward(self, x):
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])
autoencoder = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
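Once fit() finishes, the module can be used like a plain nn.Module for inference; a minimal sketch with a dummy batch of flattened images:
autoencoder.eval()
with torch.no_grad():
    x = torch.randn(4, 28 * 28)             # dummy batch of flattened images
    embedding = autoencoder(x)              # forward() returns the 3-dim embedding
    x_hat = autoencoder.decoder(embedding)  # reconstruct the 784-pixel images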
classify-simple
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
class LitClsModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = nn.Sequential(nn.Linear(28 * 28, 256),nn.BatchNorm1d(256),nn.ReLU(),
nn.Linear(256, 512),nn.BatchNorm1d(512),nn.ReLU(),
nn.Linear(512,10))
def forward(self, x):
out = self.model(x)
return out
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
out = self(x)
loss = F.cross_entropy(out, y)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
train = MNIST(os.getcwd(),train=True, download=True, transform=transforms.ToTensor())
val = MNIST(os.getcwd(),train=False, download=True, transform=transforms.ToTensor())
model = LitClsModel()
trainer = pl.Trainer(max_epochs=5,gpus=[0],log_every_n_steps=50)
trainer.fit(model, DataLoader(train,32,True), DataLoader(val,32,False))
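A minimal inference sketch after training (move the model back to the CPU so it matches the DataLoader's tensors):
model = model.cpu().eval()
with torch.no_grad():
    x, y = next(iter(DataLoader(val, batch_size=32)))
    preds = model(x.view(x.size(0), -1)).argmax(dim=1)
    print('batch accuracy: %.3f' % (preds == y).float().mean().item())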
classify
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint,EarlyStopping,LearningRateMonitor
from pytorch_lightning import loggers as pl_loggers
import numpy as np
import time
import math
import logging
logging.basicConfig(level=logging.INFO)
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
    """Linearly ramp the learning rate from warmup_factor * base_lr up to base_lr
    over the first warmup_iters steps, then hold it at base_lr."""
    def f(x):
        if x >= warmup_iters:
            return 1.0
        alpha = float(x) / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
class LitClsModel(pl.LightningModule):
def __init__(self,epochs,warpstep):
super().__init__()
self.model = nn.Sequential(nn.Linear(28 * 28, 256),nn.BatchNorm1d(256),nn.ReLU(),
nn.Linear(256, 512),nn.BatchNorm1d(512),nn.ReLU(),
nn.Linear(512,10))
self.epochs = epochs
self.warpstep = warpstep
def forward(self, x):
out = self.model(x)
return out
def training_step(self, batch, batch_idx):
"""每个step后执行"""
x, y = batch
x = x.view(x.size(0), -1)
out = self(x)
loss = F.cross_entropy(out, y)
self.log('train_loss', loss)
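        # Step the warmup scheduler manually, but only during the first epoch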
if self.current_epoch == 0:
self.warmup_lr_scheduler.step()
return loss
def on_validation_epoch_start(self) -> None:
self.start = time.time()
def on_validation_epoch_end(self) -> None:
self.end = time.time()
cost_time = self.end - self.start
self.log('cost_time', cost_time)
print('epoch:%d cost time:%.5f'%(self.current_epoch,cost_time))
def validation_step(self, batch, batch_idx):
"""每个step后执行"""
x, y = batch
x = x.view(x.size(0), -1)
out = self(x)
loss = F.cross_entropy(out, y)
self.log('val_loss', loss)
acc = (out.argmax(1)==y).sum()/out.size(0)
self.log('val_acc', acc)
return {'loss': loss, 'acc': acc}
def validation_step_end(self, batch_parts):
"""validation_step 执行完成 执行该函数"""
loss = batch_parts['loss'].item()
acc = batch_parts['acc'].item()
return {'loss': loss, 'acc': acc}
def validation_epoch_end(self, validation_step_outputs):
"""每个epoch后执行"""
loss_list = []
acc_list = []
for out in validation_step_outputs:
loss_list.append(out['loss'])
acc_list.append(out['acc'])
mean_loss = np.mean(loss_list)
mean_acc = np.mean(acc_list)
self.log('val_acc_epoch', mean_acc)
self.log('val_loss_epoch', mean_loss)
learning_rate = self.optimizers().state_dict()['param_groups'][0]['lr']
self.log('learning_rate', learning_rate)
print("epoch:%d acc:%.3f loss:%.3f lr:%.5f"%(self.current_epoch,mean_acc,mean_loss,learning_rate))
def configure_optimizers(self):
optimizer = torch.optim.Adam([param for param in self.parameters() if param.requires_grad],
lr=1e-3,weight_decay=5e-5)
lrf = 0.1
lf = lambda x: ((1 + math.cos(x * math.pi / self.epochs)) / 2) * (1 - lrf) + lrf
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
self.warmup_lr_scheduler = warmup_lr_scheduler(optimizer, self.warpstep, 1/self.warpstep)
return [optimizer],[scheduler]
train = MNIST(os.getcwd(),train=True, download=True, transform=transforms.ToTensor())
val = MNIST(os.getcwd(),train=False, download=True, transform=transforms.ToTensor())
tb_logger = pl_loggers.TensorBoardLogger('logs/')
epochs=5
batch_size = 32
warpstep = len(train)//batch_size//2
model = LitClsModel(epochs,warpstep)
trainer = pl.Trainer(callbacks=[ModelCheckpoint(monitor='val_acc', mode='max'), EarlyStopping(monitor='val_acc', mode='max')],
max_epochs=epochs,gpus=[0],
log_every_n_steps=50,
gradient_clip_val=0.1,
precision=16,
accumulate_grad_batches=4,
stochastic_weight_avg=True,
)
trainer.fit(model, DataLoader(train,batch_size,True), DataLoader(val,batch_size,False))
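If the ModelCheckpoint callback is kept in a variable instead of being built inline, the best checkpoint can be reloaded after training; a sketch (LitClsModel does not call save_hyperparameters(), so its constructor arguments must be passed again):
ckpt_cb = ModelCheckpoint(monitor='val_acc', mode='max')
# ...pass ckpt_cb inside callbacks=[...] above; then, after trainer.fit(...):
best_model = LitClsModel.load_from_checkpoint(ckpt_cb.best_model_path,
                                               epochs=epochs, warpstep=warpstep)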
Custom DataModule
- https://zhuanlan.zhihu.com/p/319810661
class MyDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        ...  # store args here: file paths, file counts, batch_size, etc.
def setup(self, stage):
if stage == 'fit' or stage is None:
self.train_dataset = DCKDataset(self.train_file_path, self.train_file_num, transform=None)
self.val_dataset = DCKDataset(self.val_file_path, self.val_file_num, transform=None)
if stage == 'test' or stage is None:
self.test_dataset = DCKDataset(self.test_file_path, self.test_file_num, transform=None)
def prepare_data(self):
pass
def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False)
dm = MyDataModule(args)
if not is_predict:
checkpoint_callback = ModelCheckpoint(monitor='val_loss')
model = MyModel()
logger = TensorBoardLogger('log_dir', name='test_PL')
dm.setup('fit')
    trainer = pl.Trainer(gpus=gpu, logger=logger, callbacks=[checkpoint_callback])
    trainer.fit(model, datamodule=dm)
else:
dm.setup('test')
model = MyModel.load_from_checkpoint(checkpoint_path='trained_model.ckpt')
trainer = pl.Trainer(gpus=1, precision=16, limit_test_batches=0.05)
trainer.test(model=model, datamodule=dm)
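DCKDataset and the file-path attributes above are specific to the original author's project. For reference, a self-contained DataModule for the MNIST data used earlier in these notes could look like this (a sketch, reusing the imports from the classification example):
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir=os.getcwd(), batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
    def prepare_data(self):
        # download once (called on a single process)
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            full = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
            self.train_dataset, self.val_dataset = random_split(full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.test_dataset = MNIST(self.data_dir, train=False, transform=transforms.ToTensor())
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)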
Saving and restoring models
autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
os.path.isfile("model.pt")
import tempfile
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
autoencoder = LitAutoEncoder()
input_sample = torch.randn((1, 28 * 28))
autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
os.path.isfile(tmpfile.name)
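Both exported files can be loaded back for inference; a minimal sketch (onnxruntime is a separate dependency, needed only for the ONNX half):
# TorchScript: load and run without the original class definition
scripted = torch.jit.load("model.pt")
emb = scripted(torch.randn(1, 28 * 28))
# ONNX: run with onnxruntime
import onnxruntime as ort
sess = ort.InferenceSession(tmpfile.name)
input_name = sess.get_inputs()[0].name
out = sess.run(None, {input_name: torch.randn(1, 28 * 28).numpy()})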
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath='my/path/', filename='sample-mnist-{epoch:02d}-{val_loss:.2f}')
Getting the best model
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path
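best_model_path can then be fed straight back into load_from_checkpoint (MyModel stands for your own LightningModule class, as above):
best_model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)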
Manually saving models
from collections import deque
import os

# Keep only the three most recent checkpoints. The deque belongs in your
# LightningModule's __init__, and manual_save_model is a method on the module.
self.save_models = deque(maxlen=3)

def manual_save_model(self):
    model_path = 'your_model_save_path_%s' % (your_loss)  # your_loss: placeholder for the metric in the filename
    if len(self.save_models) >= 3:
        # drop the oldest checkpoint before saving a new one
        old_model = self.save_models.popleft()
        if os.path.exists(old_model):
            os.remove(old_model)
    self.trainer.save_checkpoint(model_path)
    self.save_models.append(model_path)
model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
Loading a checkpoint
model = MyLightningModule.load_from_checkpoint(PATH)
print(model.learning_rate)
model.eval()
y_hat = model(x)
Restoring the model and Trainer
model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
trainer.fit(model)
training_step
def __init__(self):
    super().__init__()
    # switch to manual optimization
    self.automatic_optimization = False
def training_step(self, batch, batch_idx):
opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
loss_a = self.generator(batch)
opt_a.zero_grad()
self.manual_backward(loss_a)
opt_a.step()
loss_b = self.discriminator(batch)
opt_b.zero_grad()
self.manual_backward(loss_b)
opt_b.step()
def training_step(self, batch, batch_idx):
x, y, z = batch
out = self.encoder(x)
loss = self.loss(out, x)
return loss
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        ...  # training step for the first optimizer (e.g. the generator)
    if optimizer_idx == 1:
        ...  # training step for the second optimizer (e.g. the discriminator)
def training_step(self, batch, batch_idx, hiddens):
...
out, hiddens = self.lstm(data, hiddens)
...
return {'loss': loss, 'hiddens': hiddens}
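The hiddens argument above is only passed when truncated backpropagation through time is enabled; depending on the Lightning version this is either a module attribute or a Trainer flag (a hedged sketch, MyRNNModel is a hypothetical name):
class MyRNNModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=32, hidden_size=32, batch_first=True)
        self.truncated_bptt_steps = 2  # newer 1.x releases: attribute on the LightningModule
# older releases exposed the same setting as a Trainer flag instead:
# trainer = pl.Trainer(truncated_bptt_steps=2)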
configure_optimizers
def configure_optimizers(self):
opt = Adam(self.parameters(), lr=1e-3)
return opt
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    return generator_opt, discriminator_opt
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
    return [generator_opt, discriminator_opt], [discriminator_sched]
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
'interval': 'step'}
    dis_sched = CosineAnnealingLR(dis_opt, T_max=10)
return [gen_opt, dis_opt], [gen_sched, dis_sched]
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
n_critic = 5
return (
{'optimizer': dis_opt, 'frequency': n_critic},
{'optimizer': gen_opt, 'frequency': 1}
)
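configure_optimizers can also return a single dict; this form is needed when a scheduler such as ReduceLROnPlateau has to monitor a logged metric (a sketch):
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    return {
        'optimizer': optimizer,
        'lr_scheduler': {'scheduler': scheduler, 'monitor': 'val_loss'},
    }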
callbacks
- https://github.com/miracleyoo/pytorch-lightning-template/blob/master/classification/main.py
import pytorch_lightning.callbacks as plc
def load_callbacks():
callbacks = []
callbacks.append(plc.EarlyStopping(
monitor='val_acc',
mode='max',
patience=10,
min_delta=0.001
))
callbacks.append(plc.ModelCheckpoint(
monitor='val_acc',
filename='best-{epoch:02d}-{val_acc:.3f}',
save_top_k=1,
mode='max',
save_last=True
))
if args.lr_scheduler:
callbacks.append(plc.LearningRateMonitor(
logging_interval='epoch'))
return callbacks
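The returned list is passed straight to the Trainer, for example:
trainer = pl.Trainer(callbacks=load_callbacks(), max_epochs=100, gpus=[0])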
trainer
- https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html
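For quick reference, a sketch collecting the Trainer flags used throughout these notes (values are illustrative; the link above documents the full set):
trainer = pl.Trainer(
    max_epochs=5,               # total number of training epochs
    gpus=[0],                   # which GPUs to train on (1.x-style flag)
    precision=16,               # mixed-precision training
    gradient_clip_val=0.1,      # gradient clipping
    accumulate_grad_batches=4,  # gradient accumulation
    log_every_n_steps=50,       # logging frequency
    limit_test_batches=0.05,    # use only a fraction of the test set
    resume_from_checkpoint='some/path/to/my_checkpoint.ckpt',  # resume a previous run
)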