为了更好的了解分布式机器学习,深入理解参数服务器的理念及设计是必要的。结合最近看的李沐大神参数服务器的论文,加深对PS的理解,故整理此文。
Parameter server 使用缩写PS
提出参数服务器框架来解决分布式机器学习问题,数据和计算工作量都分布到Client节点,而服务器节点维护全局共享的参数,这些参数为稀疏向量和矩阵。PS 维护 Client和Server 之间的异步数据通信。支持灵活的一致性模型、弹性可扩展性和容错性。我们提出了挑战非凸和非光滑问题的算法和理论分析。为了验证该框架的可扩展性,我们给出了数十亿参数的真实数据的实验结果。2. 参数服务器设计要点
从上面的参数服务器架构中看出,包含两类节点,每个服务器节点维护全局共享参数的部分区(默认情况下,机器本地参数不同步),它们之间相互通信通过复制、迁移参数以实现可靠性和扩展性。Client节点做计算任务,Server节点做参数的记录保存以及全局聚合,每个client通常会在本地存储一部分训练数据,计算诸如梯度的本地统计数据,client只与server节点进行通信,获取和更新共享参数,可以添加和删除client,这样做需要将训练数据集的适当部分传输到新机器,并查询相应的参数集。 对于不同的算法,参数服务器可以同时支持多个独立的参数向量(即信道)。例如,当服务器可能正在存储一些节点正在主动查询的操作模型的参数时,同时还使用一组不同的client节点来训练新模型以供将来使用时,这很有用。这种方法极大地简化了模型更新和部署,因为这些都只需要由client来切换通道即可。模型应用示例
下面介绍一个简单的模型,
通过分布式次梯度迭代将风险最小化。目的是解决形式的优化问题
上式是一个通用的带正则化项的损失函数,其中n是样本总数,l(x,y,w)是计算单个样本的损失函数,x是特征向量,y是样本label,w是模型参数。那么模型的训练目标就是使损失函数F(w)最小。为了求解arg (min F(w)),往往使用梯度下降的方法,那么Parameter Server的主要目的就是分布式并行进行梯度下降的计算完成参数的更新与最终收敛。需要注意的是,由于公式中正则化项的存在需要汇总所有模型参数才能够正确计算,因此较难进行模型参数的并行训练,因此Parameter Server采取了和Spark MLlib一样的数据并行训练产生局部梯度,再汇总梯度更新参数权重的并行化训练方案。训练过程
具体来讲,下图以伪码方式列出了Parameter Server并行梯度下降的主要步骤:可以看到Parameter Server由server节点和worker节点组成,其主要功能分别如下:
server节点的主要功能是保存模型参数、接受worker节点计算出的局部梯度、汇总计算全局梯度,并更新模型参数
worker节点的主要功能是各保存部分训练数据,从server节点拉取最新的模型参数,根据训练数据计算局部梯度,push给server节点。
在物理架构上,PS其实是和spark的master-worker的架构基本一致的,如下图:
PS的物理架构 可以看到,PS分为两大部分:server group和多个worker group,另外resource manager负责总体的资源分配调度。通过上面的介绍可以清楚的知道PS的并行梯度下降流程,其中最关键的两个操作就是push和pull:
push:worker节点利用本节点上的训练数据,计算好局部梯度,上传给server节点;
pull:为了进行下一轮的梯度计算,worker节点从server节点拉取最新的模型参数到本地。
上面的操作,也是在论文中介绍的接口定义,现有的机器学习平台实现也使用了这样的定义。
下面通过一个Angel的一个示例了解整个PS的分布式训练流程每个worker载入一部分训练数据
worker节点从server节点pull最新的模型参数
worker节点利用本节点数据计算梯度
worker节点将梯度push到server节点
server节点汇总梯度更新模型
goto step2 直到迭代次数上限或模型收敛
Server节点的协同和效率问题
导致Spark MLlib并行训练效率低下的另一原因是每次迭代都需要master节点将模型权重参数的广播发送到各worker节点。这导致两个问题:PS server节点组成的一致性哈希环
PS的server group中应用一致性哈希的原理大致有如下几步:
将模型参数的key映射到一个环形的hash空间,比如有一个hash函数可以将任意key映射到0~(2^32)-1的hash空间内,我们只要让(2^32)-1这个桶的下一个桶是0这个桶,那么这个空间就变成了一个环形hash空间;
根据server节点的数量n,将环形hash空间等分成n*m个range,让每个server间隔地分配m个hash range。这样做的目的是保证一定的负载均衡性,避免hash值过于集中带来的server负载不均;
在新加入一个server节点时,让新加入的server节点找到hash环上的插入点,让新的server负责插入点到下一个插入点之间的hash range,这样做相当于把原来的某段hash range分成两份,新的节点负责后半段,原来的节点负责前半段。这样不会影响其他hash range的hash分配,自然不存在大量的rehash带来的数据大混洗的问题。
删除一个server节点时,移除该节点相关的插入点,让临近节点负责该节点的hash range。
PS server group中应用一致性哈希原理,其实非常有效的降低了原来单master节点带来的瓶颈问题。比如现在某worker节点希望pull新的模型参数到本地,worker节点将发送不同的range pull到不同的server节点,server节点可以并行的发送自己负责的weight到worker节点。
此外,由于在处理梯度的过程中server节点之间也可以高效协同,某worker节点在计算好自己的梯度后,也只需要利用range push把梯度发送给一部分相关的server节点即可。当然,这一过程也与模型结构相关,需要跟模型本身的实现结合起来实现。总的来说,PS基于一致性哈希提供了range pull和range push的能力,让模型并行训练的实现更加灵活。
3. 总结
参数服务器就类似于MapReduce,是大规模机器学习在不断使用过程中,抽象出来的框架之一。重点支持的就是训练数据、参数的分布式,毕竟巨大的模型其实就是巨大的参数。PS成为TensorFlow、MXNet等框架的核心组件,Angel更是以PS为基础衍生出自己的一套生态平台。后面继续从每个核心功能点深入分析。
相关文章:
1. Angel基于参数服务器的规模分布式机器学习平台
2.Angel分布式机器学习平台—LR算法示例
3.Angel中的损失函数详解
4.一文彻底搞懂Angel机器学习平台