pytorch-lightning 框架初探

项目地址 https://github.com/PyTorchLightning/pytorch-lightning
以下内容整理自项目作者的讲解视频:Converting from PyTorch to PyTorch Lightning(YouTube 视频,需自备梯子)

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(pl.LightningModule):
    """Skeleton LightningModule showing the hooks Lightning calls.

    The layers, ``forward`` and the dataloaders are placeholders to be filled
    in for a real model. The hook signatures and dict-style returns
    (``{'loss': ...}``, ``{'log': ...}``) follow an older PyTorch Lightning
    API -- NOTE(review): confirm against the installed Lightning version.

    Args:
        learning_rate: Adam learning rate used by ``configure_optimizers``.
            Defaults to 1e-3, the value this module previously hard-coded.
    """

    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.learning_rate = learning_rate
        # TODO: define the network layers here.

    def forward(self, x):
        # Placeholder. Because training_step does the loss computation,
        # forward can be kept to just the model's prediction path.
        pass

    def loss_func(self, y_hat, y):
        """Return the cross-entropy loss between predictions and targets."""
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_func(y_hat, y)
        # To also get a train_loss curve in TensorBoard, return a 'log' dict:
        #   return {'loss': loss, 'log': {'train_loss': loss}}
        return {'loss': loss}

    def log_func(self):
        # Custom logging hook: print, write files, etc.
        pass

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one ``(x, y)`` batch."""
        # NOTE: the validation DataLoader should not use shuffle.
        x, y = batch
        y_hat = self(x)
        val_loss = self.loss_func(y_hat, y)

        # Run the custom logging hook once per validation pass (first batch).
        if batch_idx == 0:
            self.log_func()

        return {'val_loss': val_loss}

    ################################
    # Dataloaders defined on the module don't need to be passed to fit().
    ################################
    def train_dataloader(self):
        # Placeholder: DataLoader needs a dataset, e.g.
        # torch.utils.data.DataLoader(train_set, batch_size=..., shuffle=True)
        loader = torch.utils.data.DataLoader()
        return loader

    def val_dataloader(self):
        # Placeholder: DataLoader needs a dataset, e.g.
        # torch.utils.data.DataLoader(val_set, batch_size=..., shuffle=False)
        loader = torch.utils.data.DataLoader()
        return loader

    ################################
    # Use a logger such as TensorBoard here instead of log_func in
    # validation_step.
    ################################
    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the result.

        ``outputs`` is the list of dicts returned by ``validation_step``.
        """
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # Other data returned by validation_step can be logged here too, e.g.
        # images reconstructed by a VAE (requires torchvision and returning
        # 'x_hat' from validation_step). With the TensorBoard logger,
        # self.logger.experiment is the tensorboard SummaryWriter:
        #   x_hat = outputs[0]['x_hat']
        #   grid = torchvision.utils.make_grid(x_hat)
        #   self.logger.experiment.add_image('images', grid, 0)

        log = {'avg_val_loss': val_loss}
        return {'log': log}
        ################################
        # In older Lightning versions, returning a dict that contains the key
        # 'val_loss' automatically triggers model checkpointing:
        # return {'val_loss': val_loss}

    
if __name__ == '__main__':

    # The dataloaders can instead be defined on the module
    # (train_dataloader / val_dataloader), in which case fit() needs none.
    # Placeholders: DataLoader needs a dataset argument.
    train_loader = torch.utils.data.DataLoader()
    val_loader = torch.utils.data.DataLoader()  # use shuffle=False for validation
    net = Net()

    # fast_dev_run=True runs one training batch and one validation batch
    # to check the whole pipeline end to end.
    trainer = pl.Trainer(fast_dev_run=True)
    # For a full training run, a plain Trainer() is enough.
    # train_percent_check=0.1 trains on only 10% of the training data.
    trainer.fit(net,
                train_dataloader=train_loader,
                val_dataloaders=val_loader,
                )

    ################################
    # Using argparse (the standard-library module is `argparse`, not
    # `argparser`).
    from argparse import ArgumentParser

    parser = ArgumentParser()
    # Add the Trainer's own command-line flags (e.g. --gpus) to the parser.
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--batch_size', default=32, type=int, help='batch size')
    parser.add_argument('--learning_rate', default=1e-3, type=float)

    args = parser.parse_args()

    # NOTE(review): --batch_size and --learning_rate are parsed but not yet
    # passed to Net or to any DataLoader; wire them in for them to matter.
    net = Net()
    trainer = pl.Trainer.from_argparse_args(args, fast_dev_run=True)
    trainer.fit(net)

    ################################
    # Single-GPU training:
    #   python main.py --gpus 1 --batch_size 256
    # Multi-GPU training: the default backend is DP (DataParallel), but DDP
    # (DistributedDataParallel) is better:
    #   python main.py --gpus 2 --distributed_backend ddp

    ################################
    # 16-bit (mixed-precision) training: PyTorch 1.6 has native AMP built in,
    # so apex is no longer required.
    # Some code may need changes, e.g. the loss function:
    #   F.binary_cross_entropy  ->
    #   F.binary_cross_entropy_with_logits(y_hat, y, reduction='sum')
    # (binary_cross_entropy is not autocast-safe; the *_with_logits form is.)

你可能感兴趣的:(pytorch-lightning 框架初探)