(给数据分析与开发加星标,提升数据技能)
一直以来,PyTorch就以简单又好用的特点,广受AI研究者的喜爱。 但是,一旦任务复杂化,就可能会发生一系列错误,花费的时间更长。 于是,就诞生了这样一个“友好”的PyTorch Lightning。 直接在GitHub上斩获6.6k星。 首先,它把研究代码与工程代码相分离,还将PyTorch代码结构化,更加直观的展现数据操作过程。 这样,更加易于理解,不易出错,本来很冗长的代码一下子就变得轻便了,对AI研究者十分的友好。 话不多说,我们就来看看这个轻量版的“PyTorch”。来源:量子位
l1 = nn.Linear(...)
l2 = nn.Linear(...)
decoder = Decoder()
x1 = l1(x)
x2 = l2(x2)
out = decoder(features, x)
loss = perceptual_loss(x1, x2, x) + CE(out, x)
而工程代码是与培训此系统相关的所有代码,比如提前停止、通过GPU分配、16位精度等。 我们知道,这些代码在大多数项目中都相同,所以在这里,直接由Trainer抽象出来。
model.cuda(0)
x = x.cuda(0)
distributed = DistributedParallel(model)
with gpu_zero:
download_data()
dist.barrier()
剩下的就是非必要代码,有助于研究项目,但是与研究项目无关,可能是检查梯度、记录到张量板。此代码由Callbacks抽象出来。
# log samples
z = Q.rsample()
generated = decoder(z)
self.experiment.log('images', generated)
此外,它还有一些的附加功能,比如你可以在CPU,GPU,多个GPU或TPU上训练模型,而无需更改PyTorch代码的一行;你可以进行16位精度训练,可以使用Tensorboard的五种方式进行记录。 这样说,可能不太明显,我们就来直观的比较一下PyTorch与PyTorch Lightning之间的差别吧。
conda activate my_env
pip install pytorch-lightning
或在没有conda环境的情况下,可以在任何地方使用pip。 代码如下:
pip install pytorch-lightning
https://github.com/PyTorchLightning/pytorch-lightning
https://pytorch-lightning.readthedocs.io/en/latest/index.html
创建者个人网站:
https://www.williamfalcon.com/
- EOF -
推荐阅读 点击标题可跳转1、万字综述,核心开发者全面解读 PyTorch 内部机制
2、手把手教你用 PyTorch 快速准确地建立神经网络
3、高性能 PyTorch 是如何炼成的?过来人吐血整理的 10 条避坑指南
看完本文有收获?请转发分享给更多人
关注「数据分析与开发」加星标,提升数据技能
好文章,我在看❤️