Parameter Server 分布式机器学习训练原理 一文读懂

上篇文章对 Spark MLlib 的并行训练方法做了详细的介绍(https://blog.csdn.net/sinat_26811377/article/details/100763590),如文章所说,Spark 采取了简单直观的数据并行的方法解决模型并行训练的问题,但由于 Spark 的并行梯度下降方法是同步阻断式的,且模型参数需通过全局广播的形式发送到各节点,因此 Spark 的并行梯度下降是相对低效的。

为了解决相应的问题,2014 年分布式可扩展的 Parameter Server 被李沐提出,几乎完美的解决了机器模型的分布式训练问题,时至今日,parameter server 不仅被直接应用在各大公司的机器学习平台上,而且也被集成在 TensorFlow,MXNet 等主流的深度框架中,作为机器学习分布式训练最重要的解决方案。

 

Parameter Server 的分布式训练原理

第一部分我们首先聚焦 PS 进行分布式训练的基本原理。这里以通用的机器学习问题为例。

Parameter Server 分布式机器学习训练原理 一文读懂_第1张图片

上式是一个通用的带正则化项的损失函数,其中 n 是样本总数,l(x,y,w) 是计算单个样本的损失函数,x 是特征向量,y 是样本 label,w 是模型参数。那么模型的训练目标就是使损失函数 F(w) 最小。为了求解 arg (min F(w)),往往使用梯度下降的方法,那么 Parameter Server 的主要目的就是分布式并行进行梯度下降的计算完成参数的更新与最终收敛。需要注意的是,由于公式中正则化项的存在需要汇总所有模型参数才能够正确计算,因此较难进行模型参数的并行训练,因此 Parameter Server 采取了和 Spark MLlib 一样的数据并行训练产生局部梯度,再汇总梯度更新参数权重的并行化训练方案。

具体来讲,图 1 以伪码方式列出了 Parameter Server 并行梯度下降的主要步骤:

Parameter Server 分布式机器学习训练原理 一文读懂_第2张图片

可以看到 Parameter Server 由 server 节点和 worker 节点组成,其主要功能分别如下:

server 节点的主要功能是保存模型参数、接受 worker 节点计算出的局部梯度、汇总计算全局梯度,并更新模型参数

worker 节点的主要功能是各保存部分训练数据,从 server 节点拉取最新的模型参数,根据训练数据计算局部梯度,上传给 server 节点。

在物理架构上,PS 其实是和 spark 的 master-worker 的架构基本一致的,具体如图 2

Parameter Server 分布式机器学习训练原理 一文读懂_第3张图片

可以看到,PS 分为两大部分:server group 和多个 worker group,另外 resource manager 负责总体的资源分配调度。

server group 内部包含多个 server node,每个 server node 负责维护一部分参数,server manager 负责维护和分配 server 资源;

每个 worker group 对应一个 application(即一个模型训练任务),worker group 之间,以及 worker group 内部的 worker node 互相之间并不通信,worker node 只与 server 通信。

结合 PS 的物理架构,PS 的并行训练整体示意图如图 3:

Parameter Server 分布式机器学习训练原理 一文读懂_第4张图片

图 3 结合图 2 描述的并行梯度下降方法的伪码以及图 2 的 PS 物理架构,清晰的描述了 PS 的并行梯度下降流程,其中最关键的两个操作就是 push 和 pull:

push:worker 节点利用本节点上的训练数据,计算好局部梯度,上传给 server 节点;

pull:为了进行下一轮的梯度计算,worker 节点从 server 节点拉取最新的模型参数到本地。

结合图 3 这里概括一下整个 PS 的分布式训练流程:

  1. 每个 worker 载入一部分训练数据
  2. worker 节点从 server 节点 pull 最新的全部模型参数
  3. worker 节点利用本节点数据计算梯度
  4. worker 节点将梯度 push 到 server 节点
  5. server 节点汇总梯度更新模型
  6. goto step2 直到迭代次数上限或模型收敛

 

一致性与并行效率之间的取舍

在上篇文章介绍 spark 的并行梯度下降原理时,曾经提到 spark 并行梯度下降效率较低的原因就是每个节点都需要等待其他所有节点的梯度都计算完后,master 节点汇总梯度,计算好新的模型参数后,才能开始下一轮的梯度计算,我们称这种方式为 “同步阻断式” 的并行梯度下降过程。

同步阻断式 “的并行梯度下降虽然是严格意义上的一致性最强的梯度下降方法,因为其计算结果和串行计算的过程一直,但效率过低,各节点的 waiting 时间过长,有没有办法提高梯度下降的并行度呢?

PS 采取的方法是用 “异步非阻断式” 的梯度下降替代原来的同步式方法。图 4 是一个 worker 节点多次迭代计算梯度的过程,可以看到节点在做第 11 次迭代(iter 11)计算时,第 10 次迭代后的 push&pull 过程并没有结束,也就是说最新的模型权重参数还没有被拉取到本地,该节点仍使用的是 iter 10 的权重参数计算的 iter 11 的梯度。这就是所谓的异步非阻断式梯度下降方法,其他节点计算梯度的进度不会影响本节点的梯度计算。所有节点始终都在并行工作,不会被其他节点阻断。

Parameter Server 分布式机器学习训练原理 一文读懂_第5张图片

用下面转载了异步更新和同步更新的两个动画,大家可以非常直观的了解异步更新和同步更新的过程和区别。

Parameter Server 分布式机器学习训练原理 一文读懂_第6张图片

Parameter Server 分布式机器学习训练原理 一文读懂_第7张图片

当然,任何的技术方案都是取舍,异步梯度更新的方式虽然大幅加快了训练速度,但带来的是模型一致性的丧失,也就是说并行训练的结果与原来的单点串行训练的结果是不一致的,这样的不一致会对模型收敛的速度造成一定影响。所以最终选取同步更新还是异步更新取决于不同模型对于一致性的敏感程度。这类似于一个模型超参数选取的问题,需要针对具体问题进行具体的验证。

除此之外,在同步和异步之间,还可以通过一些 “最大延迟” 等参数来限制异步的程度。比如可以限定在三轮迭代之内,模型参数必须更新一次,那么如果某 worker 节点计算了三轮梯度,该节点还未完成一次从 server 节点 pull 最新模型参数的过程,那么该 worker 节点就必须停下等待 pull 操作的完成。这是同步和异步之间的折衷方法。

在 PS 论文的原文中也提供了异步和同步更新的效率对比,这里可以作为参考(基于 Sparse logistic regression 模型训练)。

Parameter Server 分布式机器学习训练原理 一文读懂_第8张图片

SystemA 和 B 都是同步更新梯度的系统,PS 是异步更新的策略,可以看到 PS 的 computing 占比远高于同步更新策略

Parameter Server 分布式机器学习训练原理 一文读懂_第9张图片

可以看到异步更新的 PS 的收敛速度也远胜于同步更新的 SystemA 和 B,这证明异步更新带来的梯度不一致性的影响没有想象中那么大。

 

多 server 节点的协同和效率问题

导致 Spark MLlib 并行训练效率低下的另一原因是每次迭代都需要 master 节点将模型权重参数的广播发送到各 worker 节点。这导致两个问题:

1.master 节点作为一个瓶颈节点,受带宽条件的制约,发送全部模型参数的效率不高;

2. 同步地广播发送所有权重参数,使系统整体的网络负载非常大。

那么 PS 是如何解决单点 master 效率低下的问题呢?从图 2 的架构图中可知,PS 采用了 server group 内多 server 的架构,每个 server 主要负责一部分的模型参数。模型参数使用 key value 的形式,每个 server 负责一个 key 的 range 就可以了。

那么另一个问题来了,每个 server 是如何决定自己负责哪部分 key range 呢?如果有新的 server 节点加入,又是如何在保证已有 key range 不发生大的变化的情况下加入新的节点呢?这两个问题的答案涉及到一致性哈希(consistent hashing)的原理。

Parameter Server 分布式机器学习训练原理 一文读懂_第10张图片

PS 的 server group 中应用一致性哈希的原理大致有如下几步:

1. 将模型参数的 key 映射到一个环形的 hash 空间,比如有一个 hash 函数可以将任意 key 映射到 0~(2^32)-1 的 hash 空间内,我们只要让 (2^32)-1 这个桶的下一个桶是 0 这个桶,那么这个空间就变成了一个环形 hash 空间;

2. 根据 server 节点的数量 n,将环形 hash 空间等分成 n*m 个 range,让每个 server 间隔地分配 m 个 hash range。这样做的目的是保证一定的负载均衡性,避免 hash 值过于集中带来的 server 负载不均;

3. 在新加入一个 server 节点时,让新加入的 server 节点找到 hash 环上的插入点,让新的 server 负责插入点到下一个插入点之间的 hash range,这样做相当于把原来的某段 hash range 分成两份,新的节点负责后半段,原来的节点负责前半段。这样不会影响其他 hash range 的 hash 分配,自然不存在大量的 rehash 带来的数据大混洗的问题。

4. 删除一个 server 节点时,移除该节点相关的插入点,让临近节点负责该节点的 hash range。

PS server group 中应用一致性哈希原理,其实非常有效的降低了原来单 master 节点带来的瓶颈问题。比如现在某 worker 节点希望 pull 新的模型参数到本地,worker 节点将发送不同的 range pull 到不同的 server 节点,server 节点可以并行的发送自己负责的 weight 到 worker 节点。

此外,由于在处理梯度的过程中 server 节点之间也可以高效协同,某 worker 节点在计算好自己的梯度后,也只需要利用 range push 把梯度发送给一部分相关的 server 节点即可。当然,这一过程也与模型结构相关,需要跟模型本身的实现结合起来实现。总的来说,PS 基于一致性哈希提供了 range pull 和 range push 的能力,让模型并行训练的实现更加灵活。

 

Parameter Server 的技术要点总结

总结一下 Parameter Server 实现分布式机器学习模型训练的要点:

  1. 用异步非阻断式的分布式梯度下降策略替代同步阻断式的梯度下降策略;
  2. 实现多 server 节点的架构,避免了单 master 节点带来的带宽瓶颈和内存瓶颈;
  3. 使用一致性哈希,range pull 和 range push 等工程手段实现信息的最小传递,避免广播操作带来的全局性网络阻塞和带宽浪费。

但是大家要清楚的是,Parameter Server 仅仅是一个管理并行训练梯度的权重的平台,并不涉及到具体的模型实现,因此 PS 往往是作为 MXNet,TensorFlow 的一个组件,要想具体实现一个机器学习模型,还需要依赖于通用的,综合性的机器学习平台。那么下一篇文章,我们就来介绍一下以 TensorFlow 为代表的机器学习平台的工作原理,特别是并行训练的原理。

你可能感兴趣的:(算法,Algorithm)