用PyTorch对Leela Zero进行神经网络训练

作者|Peter Yu
编译|Flin
来源|towardsdatascience

用PyTorch对Leela Zero进行神经网络训练_第1张图片

最近,我一直在寻找方法来加快我的研究和管理我的实验,特别是围绕着写训练管道和管理实验配置文件这两个方面,我发现这两个新项目叫做PyTorch Lightning和Hydra。PyTorch Lightning可以帮助你快速编写训练管道,而Hydra可以帮助你有效地管理配置文件。

为了练习使用它们,我决定为Leela Zero(https://github.com/leela-zero... 编写一个训练管道。我这样做,是因为这是一个范围很广的项目,涉及到使用多个gpu在大数据集上训练大型网络,可以说是一个十分有趣的技术挑战。此外,我以前曾经实现过一个更小版本的AlphaGo国际象棋(https://medium.com/@peterkeun...) ,所以我认为这将是一个有趣的业余项目。

在这个博客中,我将解释这个项目的主要细节,以便你能够轻松理解我所做的工作。你可以在这里阅读我的代码:https://github.com/yukw777/le...

Leela Zero

第一步是找出Leela Zero神经网络的内部工作原理。我大量引用了Leela Zero的文档和它的Tensorflow训练管道。

神经网络结构

Leela Zero的神经网络由一个残差塔(ResNet “tower” )组成,塔上有两个“head”,即AlphaGo Zero论文(https://deepmind.com/blog/art...) 中描述的负责策略的“头”(policy head)和负责计算价值的“头”(value head)。就像论文所述,策略“头”和值“头”开始的那几个卷积滤波器都是1x1,其他所有的卷积滤波器都是3x3。游戏和棋盘特征被编码为[批次大小,棋盘宽度,棋盘高度,特征数量]形状的张量,首先通过残差塔输入。然后,塔提取出抽象的特征,并通过每个“头”输入这些特征,以计算下一步棋的策略概率分布和游戏的价值,从而预测游戏的获胜者。

你可以在下面的代码片段中找到网络的实现细节。

用PyTorch对Leela Zero进行神经网络训练_第2张图片

权重格式

Leela Zero使用一个简单的文本文件来保存和加载网络权重。文本文件中的每一行都有一系列数字,这些数字表示网络的每一层的权重。首先是残差塔,然后是策略头,然后是值头。

卷积层有2个权重行:

  1. 与[output, input, filter size, filter size]形状的卷积权值
  2. 通道的偏差

Batchnorm层有2个权重行:

  1. Batchnorm平均值
  2. Batchnorm方差

内积(完全连接)层有2个权重行:

  1. 带有[output, input]形状的层权重
  2. 输出偏差

我编写了单元测试来确保我的权重文件是正确的。我使用的另一个简单的完整性检查是计算层的数量,在加载我的权值文件后,将其与Leela Zero进行比较。层数公式为:

n_layers = 1 (version number) +
           2 (input convolution) + 
           2 (input batch norm) + 
           n_res (number of residual blocks) * 
           8 (first conv + first batch norm + 
              second conv + second batch norm) + 
           2 (policy head convolution) + 
           2 (policy head batch norm) + 
           2 (policy head linear) + 
           2 (value head convolution) + 
           2 (value head batch norm) + 
           2 (value head first linear) + 
           2 (value head second linear)

到目前为止,这看起来很简单,但是你需要注意一个实现细节。Leela Zero实际上使用卷积层的偏差来表示下一个归一化层(batch norm)的可学习参数(gammabeta)。这样做是为了使权值文件的格式(只有一行表示层权值,另一行表示偏差)在添加归一化层时不必更改。

目前,Leela Zero只使用归一化层的beta项,将gamma设置为1。那么,实际上我们该如何使用卷积偏差,来产生与在归一化层中应用可学习参数相同的结果呢?我们先来看看归一化层的方程:

y = gamma * (x — mean)/sqrt(var — eps) + beta

由于Leela Zero将gamma设为1,则方程为:

y = (x — mean)/sqrt(var — eps) + beta

现在,设定x_conv是没有偏差的卷积层的输出。然后,我们想给x_conv添加一些偏差,这样当你在没有beta的归一化层中运行它时,结果与在只有beta的归一化层方程中运行x_conv是一样的:

(x_conv + bias — mean)/sqrt(var — eps) = 
(x_conv — mean)/sqrt(var — eps) + beta 
x_conv + bias — mean = 
x_conv — mean + beta * sqrt(var — eps) 
bias = beta * sqrt(var — eps)

因此,如果我们在权值文件中将卷积偏差设置为beta * sqrt(var - eps),我们就会得到期望的输出,这就是LeelaZero所做的。

那么,我们如何实现它呢?在Tensorflow中,你可以通过调用tf.layers.batch_normalization(scale=False)来告诉归一化层要忽略gamma项,然后使用它。

遗憾的是,在PyTorch中,你不能将归一化层设置为只忽略gamma,你只能通过将仿射参数设置为False: BatchNorm2d(out_channels, affine=False),来忽略gammabeta。所以,我把归一化层设为两个都忽略,然后简单地在后面加上一个张量,它表示beta。然后,使用公式bias = beta * sqrt(var - eps)来计算权值文件的卷积偏差。

训练管道

在弄清了Leela Zeros的神经网络的细节之后,就到了处理训练管道的时候了。正如我提到的,我想练习使用两个工具:PyTorch Lightning和Hydra,来加快编写训练管道和有效管理实验配置。让我们来详细了解一下我是如何使用它们的。

PyTorch Lightning

编写训练管道是我研究中最不喜欢的部分:它涉及大量重复的样板代码,而且很难调试。正因为如此,PyTorch Lightning对我来说就像一股清流,它是一个轻量级的库,PyTorch没有很多辅助抽象,在编写训练管道时,它负责处理大部分样板代码。它允许你关注你的训练管道中更有趣的部分,比如模型架构,并使你的研究代码更加模块化和可调试。此外,它还支持多gpu和TPU的开箱即用训练!

为了使用PyTorch Lightning作为我的训练管道,我需要做的最多的编码就是编写一个类,我称之为NetworkLightningModule,它继承自LightningModule来指定训练管道的细节,并将其传递给训练器。有关如何编写自己的LightningModule的详细信息,可以参考PyTorch Lightning的官方文档。

Hydra

我一直在研究的另一部分是实验管理。当你进行研究的时候,你不可避免地要运行大量不同的实验来测试你的假设,所以,以一种可扩展的方式跟踪它们是非常重要的。到目前为止,我一直依赖于配置文件来管理我的实验版本,但是使用平面配置文件很快就变得难以管理。使用模板是这个问题的一个解决方案。然而,我发现模板最终也会变得混乱,因为当你覆盖多个层的值文件来呈现你的配置文件时,很难跟踪哪个值来自哪个值文件。

另一方面,Hydra是一个基于组件的配置管理系统。与使用单独的模板和值文件来呈现最终配置不同,你可以组合多个较小的配置文件来组成最终配置。它不如基于模板的配置管理系统灵活,但我发现基于组件的系统在灵活性和可维护性之间取得了很好的平衡。Hydra就是这样一个专门为研究脚本量身定做的系统。它的调用有点笨拙,因为它要求你将它用作脚本的主要入口点,但实际上我认为有了这种设计,它很容易与你的训练脚本集成。此外,它允许你通过命令行手动覆盖配置,这在运行你的实验的不同版本时非常有用。我常常使用Hydra管理不同规模的网络架构和训练管道配置。

评估

为了评估我的训练网络,我使用GoMill(https://github.com/mattheww/g...) 来举行围棋比赛。它是一个运行在Go Text Protocol (GTP)引擎上的比赛的库,Leela Zero就是其中之一。你可以在这里(https://github.com/yukw777/le...) 找到我使用的比赛配置。

结论

通过使用PyTorch-Lightning和Hydra,能够极大地加快编写训练管道的速度,并有效地管理实验配置文件。我希望这个项目和博客文章也能对你的研究有所帮助。你可以在这里查看代码:https://github.com/yukw777/le...

原文链接:https://towardsdatascience.co...

欢迎关注磐创AI博客站:
http://panchuang.net/

sklearn机器学习中文官方文档:
http://sklearn123.com/

欢迎关注磐创博客资源汇总站:
http://docs.panchuang.net/

你可能感兴趣的:(人工智能)