Spark MLlib 分布式机器学习并行训练原理 一文读懂

在笔者看来,分布式机器学习训练有三个主要的方案,分别是Spark MLlib,Parameter ServerTensorFlow,倒不是说他们是唯三可供选择的平台,而是因为他们分别代表着三种主流的解决分布式训练方法。今天我们先从Spark MLlib说起,看看最流行的大数据计算平台是如何处理机器学习模型的并行训练问题的。

说起Spark,我想不会有任何算法工程师是陌生的。作为流行了至少五年的大数据项目,虽然受到了诸如Flink等后起之秀的挑战,但其仍是当之无愧的业界最主流的计算平台。而且为了照顾数据处理和模型训练平台的一致性,也有大量公司采用Spark原生的机器学习平台MLlib进行模型训练。选择Spark MLlib作为机器学习分布式训练平台的第一站,不仅因为Spark是流行的,更是因为Spark MLlib的并行训练方法代表着一种朴素的,直观的解决方案。

 

Spark 的分布式计算原理

在介绍 Spark MLlib 的分布式机器学习训练方法之前,让我们先回顾一下 Spark 的分布式计算原理,这是分布式机器学习的基础。

Spark,是一个分布式的计算平台。所谓分布式,指的是计算节点之间不共享内存,需要通过网络通信的方式交换数据。要清楚的是,Spark 最典型的应用方式是建立在大量廉价计算节点上,这些节点可以是廉价主机,也可以是虚拟的 docker container;但这种方式区别于 CPU+GPU 的架构,或者共享内存多处理器的高性能服务器架构。清楚这一点,对于理解后续的 Spark 的计算原理是重要的。

Spark MLlib 分布式机器学习并行训练原理 一文读懂_第1张图片

从图 1 的 Spark 架构图中可以看到,Spark 程序由 Manager node 进行调度组织,由 Worker Node 进行具体的计算任务执行,最终将结果返回给 Drive Program。在物理的 worker node 上,数据还可能分为不同的 partition,可以说 partition 是 spark 的基础处理单元。

在执行具体的程序时,Spark 会将程序拆解成一个任务 DAG(有向无环图),再根据 DAG 决定程序各步骤执行的方法。如图 2 所示,该程序先分别从 textFile 和 HadoopFile 读取文件,经过一些列操作后再进行 join,最终得到处理结果。

Spark MLlib 分布式机器学习并行训练原理 一文读懂_第2张图片

在 Spark 平台上并行处理图 2 的 DAG 时,最关键的过程是找到哪些是可以并行处理的部分,哪些是必须 shuffle 和 reduce 的部分。

这里的 shuffle 指的是所有 partition(数据分片)的数据必须进行洗牌后才能得到下一步的数据,最典型的操作就是图 2 中的 groupByKey 和 join 操作。拿 join 操作来说,必须通过在 textFile 数据中和 hadoopFile 数据中做全量的匹配才可以得到 join 后的 dataframe。而 groupby 操作需要对数据中所有相同的 key 进行合并,也需要全局的 shuffle 才能够完成。

与之相比,map,filter 等操作仅需要逐条的进行数据处理和转换就可以,不需要进行数据间的操作,因此各 partition 之间可以并行处理。

除此之外,在得到最终的计算结果之前,程序需要进行 reduce 的操作,从各 partition 上汇总统计结果,随着 partition 的数量逐渐减小,reduce 操作的并行程度逐渐降低,直到将最终的计算结果汇总到 master 节点上。

所以可以说 shuffle 和 reduce 操作的发生决定了纯并行处理阶段的边界。如图 3 所示,Spark 的 DAG 被分割成了不同的并行处理阶段(stage)。

Spark MLlib 分布式机器学习并行训练原理 一文读懂_第3张图片

 

需要强调的是 shuffle 操作需要在不同计算节点之间进行数据交换,非常消耗计算、通信及存储资源,因此 shuffle 操作是 spark 程序应该尽量避免的。一句话总结 Spark 的计算过程就是:Stage 内部数据高效并行计算,Stage 边界处进行消耗资源的 shuffle 操作或者最终的 reduce 操作。

 

Spark MLlib 并行训练原理

有了 Spark 分布式计算过程的基础,下面就可以更清楚的理解 Spark MLlib 并行训练的原理。

在所有主流的机器学习模型中,Random Forest 的模型结构特点决定了其可以完全进行数据并行的模型训练,而 GBDT 的结构特点则决定了树之间只能进行串行的训练,这里就不再赘述其 spark 的实现方式,我们将重点放在梯度下降类方法的实现上,因为梯度下降的并行程度实现质量直接决定了以 Logistic Regression 为基础,以 Multiple Layer Perceptron 为代表的深度学习模型的训练速度。

这里,我们深入到 Spark MLlib 的源码中,直接把 Spark 做 mini Batch 梯度下降的源码贴出如下(代码摘自 Spark 2.4.3 GradientDescent 类的 runMiniBatchSGD 函数):

Spark MLlib 分布式机器学习并行训练原理 一文读懂_第4张图片

乍一看比较复杂,这里可以为大家做一个精简,只列出关键的操作部分,大家就可以一目了然 Spark 在做什么。

Spark MLlib 分布式机器学习并行训练原理 一文读懂_第5张图片

经过精简的代码非常简单,Spark 的 mini batch 过程制作了三件事:

  1. 把当前的模型参数广播到各个数据 partition(可当作虚拟的计算节点)
  2. 各计算节点进行数据抽样得到 mini batch 的数据,分别计算梯度,再通过 treeAggregate 操作汇总梯度,得到最终梯度 gradientSum
  3. 利用 gradientSum 更新模型权重

这样一来,每次迭代的 Stage 和 Stage 的边界就非常清楚了,Stage 内部的并行部分是各节点分别采样并计算梯度的过程,Stage 的边界是汇总加和各节点梯度的过程。这里再强调一下汇总梯度的操作 treeAggregate,该操作是进行类似树结构的逐层汇总,整个操作流程如图 4 所示。

 

Spark MLlib 分布式机器学习并行训练原理 一文读懂_第6张图片

事实上,treeAggregate 是一次 reduce 操作,本身并不包含 shuffle 操作,再加上采用分层的树形操作,在每层中都是并行执行的,因此整个过程是相对高效的。

在迭代次数达到上限或者模型已经充分收敛后,模型停止训练。这就是 Spark MLlib 进行 mini batch 梯度下降的全过程,也是 Spark MLlib 实现分布式机器学习的最典型代表。

总结来说,Spark MLlib 的并行训练过程其实是 “数据并行” 的过程,并不涉及到过于复杂的梯度更新策略,也没有通过 “参数并行” 的方式实现并行训练。这样的方式简单、直观,易于实现,但也存在着一些局限性。

 

Spark MLlib 并行训练的局限性

虽然 Spark MLlib 基于分布式集群,利用数据并行的方式实现了梯度下降的并行训练,但是有 Spark MLlib 使用经验的同学应该都清楚,使用 Spark MLlib 训练复杂神经网络时,往往力不从心,不仅训练时间过长,而且在模型参数过多时,经常会存在内存溢出的问题。具体来讲,Spark MLlib 的分布式训练方法有下面几个弊端:

1. 采用全局广播的方式,在每轮迭代前广播全部模型参数。众所周知 Spark 的广播过程非常消耗带宽资源,特别是当模型的参数规模过大时,广播过程和在每个节点都维护一个权重参数副本的过程都是极消耗资源的过程,这导致了 Spark 在面对复杂模型时的表现不佳;

2. 采用阻断式的梯度下降方式,每轮梯度下降由最慢的节点决定。从上面的分析可知,Spark MLlib 的 mini batch 的过程是在所有节点计算完各自的梯度之后,逐层 Aggregate 最终汇总生成全局的梯度。也就是说,如果由于数据倾斜等问题导致某个节点计算梯度的时间过长,那么这一过程将 block 其他所有节点无法执行新的任务。这种同步阻断的分布式梯度计算方式,是 Spark MLlib 并行训练效率较低的主要原因;

3.Spark MLlib 并不支持复杂网络结构和大量可调超参。事实上,Spark MLlib 在其标准库里只支持标准的多层感知机神经网络的训练,并不支持 RNN,LSTM 等复杂网络结构,而且也无法选择不同的 activation function 等大量超参。这就导致 Spark MLlib 在支持深度学习方面的能力欠佳。

因为这些原因,如果想寻求更高效的训练速度和更灵活的网络结构,势必需要寻求其他平台的帮助。在这样的情势下,Parameter Server 凭借其高效的分布式训练手段成为分布式机器学习的主流,而 TensorFlow,PyTorch 等深度学习平台则凭借灵活可调整的网络结构,完整的训练、上线支持,成为深度学习平台的主要选择。下两篇内容,本专栏将分别介绍 Patameter Server 和 TensorFlow 的并行训练原理。

例行的问题讨论时间,其他与你一起讨论和分享业界相关的经验:

1. 如果希望在 Spark 上训练深度学习模型,你有没有改进 Spark 的方法?使用第三方 lib?还是修改 Spark 源码?还是自研 Spark 模型?

2. 在训练完成 Spark 模型后,应该使用什么方式将 Spark 模型 deploy 到线上环境,做线上的实时 inference?

你可能感兴趣的:(Spark)