用ADMM实现统计学习问题的分布式计算

用ADMM实现统计学习问题的分布式计算

最近研读了 Boyd 2011 年那篇关于 ADMM 的综述。这是由两件事情促成的:一是刘鹏的《计算广告》中 CTR 预测技术那部分提到了用 ADMM 框架来并行化 LR 问题的求解;二是我的一个本科同学在读博士期间做的东西恰好用到了 ADMM,我们曾经谈到过这个框架。我从这篇综述里整理出了一个条思路,顺着这个思路看下去,就能对 ADMM 原理和应用有个大概的了解。因此,此文可以当做 ADMM 的快速入门。

交替方向乘子法(Alternating Direction Method of Multipliers,ADMM)是一种求解优化问题的计算框架, 适用于求解分布式凸优化问题,特别是统计学习问题。 ADMM 通过分解协调(Decomposition-Coordination)过程,将大的全局问题分解为多个较小、较容易求解的局部子问题,并通过协调子问题的解而得到大的全局问题的解。

ADMM 最早分别由 Glowinski & Marrocco 及 Gabay & Mercier 于 1975 年和 1976 年提出,并被 Boyd 等人于 2011 年重新综述并证明其适用于大规模分布式优化问题。由于 ADMM 的提出早于大规模分布式计算系统和大规模优化问题的出现,所以在 2011 年以前,这种方法并不广为人知。

ADMM 计算框架

一般问题

若优化问题可表示为

minf(x)+g(z)s.t.Ax+Bz=c(1)(1)minf(x)+g(z)s.t.Ax+Bz=c

其中 xRs,zRn,ARp×s,BRp×n,cRp,f:RsR,g:RnRx∈Rs,z∈Rn,A∈Rp×s,B∈Rp×n,c∈Rp,f:Rs→R,g:Rn→R。 xx 与zz 是优化变量;f(x)+g(z)f(x)+g(z) 是待最小化的目标函数(Objective Function),它由与变量 xx 相关的 f(x)f(x) 和与变量 xx 相关的 g(z)g(z) 这两部分构成,这种结构可以很容易地处理统计学习问题优化目标中的正则化项; Ax+Bz=cAx+Bz=c 是 pp 个等式约束条件(Equality Constraints)的合写。其增广拉格朗日函数(Augmented Lagrangian)为

Lρ(x,z,y)=f(x)+g(z)+yT(Ax+Bzc)+(ρ/2)Ax+Bzc22Lρ(x,z,y)=f(x)+g(z)+yT(Ax+Bz−c)+(ρ/2)∥Ax+Bz−c∥22

其中 yy 是对偶变量(或称为拉格朗日乘子), ρ>0ρ>0 是惩罚参数。 Lρ 名称中的“增广”是指其中加入了二次惩罚项 (ρ/2)Ax+Bzc22(ρ/2)∥Ax+Bz−c∥22 。

则该优化问题的 ADMM 迭代求解方法为

xk+1:=argminxLρ(x,zk,yk)zk+1:=argminzLρ(xk+1,z,yk)yk+1:=yk+ρ(Axk+1+Bzk+1c)xk+1:=arg⁡minxLρ(x,zk,yk)zk+1:=arg⁡minzLρ(xk+1,z,yk)yk+1:=yk+ρ(Axk+1+Bzk+1−c)

令 u=(1/ρ)yu=(1/ρ)y ,并对 Ax+BzcAx+Bz−c 配方,可得表示上更简洁的缩放形式(Scaled Form)

xk+1:=argminx(f(x)+(ρ/2)Ax+Bzkc+uk22)zk+1:=argminz(g(z)+(ρ/2)Axk+1+Bzc+uk22)uk+1:=uk+Axk+1+Bzk+1cxk+1:=arg⁡minx(f(x)+(ρ/2)∥Ax+Bzk−c+uk∥22)zk+1:=arg⁡minz(g(z)+(ρ/2)∥Axk+1+Bz−c+uk∥22)uk+1:=uk+Axk+1+Bzk+1−c

可以看出,每次迭代分为三步:

  1. 求解与 xx 相关的最小化问题,更新变量 xx
  2. 求解与 zz 相关的最小化问题,更新变量 zz
  3. 更新对偶变量 uu

ADMM名称中的“乘子法”是指这是一种使用增广拉格朗日函数(带有二次惩罚项)的对偶上升(Dual Ascent)方法,而“交替方向”是指变量 xx 和 zz 是交替更新的。两变量的交替更新是在 f(x)f(x) 或 g(z)g(z) 可分时可以将优化问题分解的关键原因。

收敛性

可以证明,当满足条件

  1. 函数 f,gf,g 具有 closed, proper, convex 的性质
  2. 拉格朗日函数 L0L0 有鞍点

时,ADMM 的迭代收敛(当 kk→∞ 时, rk0,f(xk)+g(zk)p,ykyrk→0,f(xk)+g(zk)→p⋆,yk→y⋆ )。这样的收敛条件比没有使用增广拉格朗日函数的对偶上升法的收敛条件宽松了不少。

在高精度要求下,ADMM 的收敛很慢;但在中等精度要求下,ADMM 的收敛速度可以接受(几十次迭代)。因此 ADMM 框架尤其适用于不要求高精度的优化问题,这恰好是大规模统计学习问题的特点。

一致性(Consensus)问题

一类可用 ADMM 框架解决的特殊优化问题是一致性(Consensus)问题,其形式为

minNi=1fi(z)+g(z)min∑i=1Nfi(z)+g(z)

将加性优化目标 Ni=1fi(z)∑i=1Nfi(z) 转化为可分优化目标 Ni=1fi(xi)∑i=1Nfi(xi) ,并增加相应的等式约束条件,可得其等价问题

minNi=1fi(xi)+g(z)s.t.xiz=0,i=1,,N(2)(2)min∑i=1Nfi(xi)+g(z)s.t.xi−z=0,i=1,…,N

这里约束条件要求每个子目标中的局部变量 xixi 与全局变量 zz 一致,因此该问题被称为一致性问题。

可以看出,令式(1)中的x=(xT1,,xTN)T,f(x)=Ni=1fi(xi),A=IsN,B=(Is,,IsN)T,c=0x=(x1T,…,xNT)T,f(x)=∑i=1Nfi(xi),A=IsN,B=−(Is,…,Is⏟N)T,c=0 ,即得到式(2)。因此 Consensus 问题可用 ADMM 框架求解,其迭代方法为

xk+1i:=argminxi(fi(xi)+(ρ/2)xizk+uki22)zk+1:=argminz(g(z)+(Nρ/2)z¯¯¯xk+1¯¯¯uk22)uk+1i:=uki+xk+1izk+1xik+1:=arg⁡minxi(fi(xi)+(ρ/2)∥xi−zk+uik∥22)zk+1:=arg⁡minz(g(z)+(Nρ/2)∥z−x¯k+1−u¯k∥22)uik+1:=uik+xik+1−zk+1

其中 ¯¯¯x=(1/N)Ni=1xi,¯¯¯u=(1/N)Ni=1uix¯=(1/N)∑i=1Nxi,u¯=(1/N)∑i=1Nui 。

可以看出,变量 xx 和对偶变量 uu 的更新都是可以采用分布式计算的。只有在更新变量 zz 时,需要收集 xx 和 uu 分布式计算的结果,进行集中式计算。

统计学习问题应用

统计学习问题也是模型拟合问题,可表示为

minl(D,d,z)+r(z)minl(D,d,z)+r(z)

其中 zRnz∈Rn 是待学习的参数, DRm×nD∈Rm×n 是模型的输入数据集, dRmd∈Rm 是模型的输出数据集, l:Rm×n×Rm×RnRl:Rm×n×Rm×Rn→R 是损失函数, r:RnRr:Rn→R 是正则化项, mm表示数据的个数, nn 表示特征的个数。

对于带L1正则化项的线性回归(Lasso),其平方损失函数为

l(D,d,z)=(1/2)Dzd22l(D,d,z)=(1/2)∥Dz−d∥22

对于逻辑回归(Logistic Regression),其极大似然损失函数为

l(D,d,z)=1T(log(exp(Dz)+1)DzdT)l(D,d,z)=1T(log⁡(exp⁡(Dz)+1)−DzdT)

对于线性支持向量机(Linear Support Vector Machine),其合页(Hinge)损失函数为

l(D,d,z)=1T(1DzdT)+l(D,d,z)=1T(1−DzdT)+

将训练数据集(输入数据和输出数据)在样本的维度( mm )划分成 NN 块

D=⎜⎜D1DN⎟⎟,d=⎜⎜d1dN⎟⎟,D=(D1⋮DN),d=(d1⋮dN),

其中 DiRmi×n,diRmi,Ni=1mi=mDi∈Rmi×n,di∈Rmi,∑i=1Nmi=m ,若有局部损失函数 li:Rmi×n×Rmi×RnRli:Rmi×n×Rmi×Rn→R ,可得

minNi=1li(Di,di,xi)+r(z)s.t.xiz=0,i=1,,N(3)(3)min∑i=1Nli(Di,di,xi)+r(z)s.t.xi−z=0,i=1,…,N

可以看出,令式(2)中的 fi(xi)=li(Di,di,xi),g(z)=r(z)fi(xi)=li(Di,di,xi),g(z)=r(z) ,即得到式(3),因此 统计学习问题可用 Consensus ADMM 实现分布式计算,其迭代方法为

xk+1i:=argminxi(li(Di,di,xi)+(ρ/2)xizk+uki22)zk+1:=argminz(r(z)+(Nρ/2)z¯¯¯xk+1¯¯¯uk22)uk+1i:=uki+xk+1izk+1xik+1:=arg⁡minxi(li(Di,di,xi)+(ρ/2)∥xi−zk+uik∥22)zk+1:=arg⁡minz(r(z)+(Nρ/2)∥z−x¯k+1−u¯k∥22)uik+1:=uik+xik+1−zk+1

分布式实现

MPI

MPI 是一个语言无关的并行算法消息传递规约。使用 MPI 范式的 Consensus ADMM 算法如下所示。

  1. Initialize NN processes, along with xi,ui,ri,zxi,ui,ri,z

  2. Repeat

  3.     Update ri=xizri=xi−z

  4.     Update ui:=ui+xizui:=ui+xi−z

  5.     Update xi:=argminx(fi(x)+(ρ/2)xz+ui22)xi:=arg⁡minx(fi(x)+(ρ/2)∥x−z+ui∥22)

  6.     Let w:=xi+uiw:=xi+ui and t:=ri22t:=∥ri∥22

  7.     Allreduce ww and tt

  8.     Let zprev:=zzprev:=z and update z:=argminz(g(z)+(Nρ/2)|zw/>N|22)z:=arg⁡minz(g(z)+(Nρ/2)|z−w/>N|22)

  9. Until ρNzzprev2ϵconvρN∥z−zprev∥2≤ϵconv and tϵfeast≤ϵfeas

该算法中假设有 NN 个处理器,每个处理器都运行同样的程序,只是处理的数据不同。第6步中的 Allreduce 是 MPI 中定义的操作,表示对相应的局部变量进行全局操作(如这里的求和操作),并将结果更新到每一个处理器。

MapReduce

MapReduce 是一个在工业界和学术界都很流行的分布式批处理编程模型。使用 MapReduce 范式的 Consensus ADMM 算法(一次迭代)如下所示。

Function map(key ii , dataset DiDi )

  1. Read (xi,ui,^z)(xi,ui,z^) from distributed database

  2. Compute z:=argminz(g(z)+(Nρ/2)|z^z/N|22)z:=arg⁡minz(g(z)+(Nρ/2)|z−z^/N|22)

  3. Update ui:=ui+xizui:=ui+xi−z

  4. Update xi:=argminx(fi(x)+(ρ/2)|xz+ui|22)xi:=arg⁡minx(fi(x)+(ρ/2)|x−z+ui|22)

  5. Emit (key CENTRALCENTRAL , record (xi,ui)(xi,ui) )

EndFunction

Function reduce (key CENTRALCENTRAL , records (x1,u1),,(xN,uN)(x1,u1),…,(xN,uN) )

  1. Update ^z:=Ni=1(xi+ui)z^:=∑i=1N(xi+ui)

  2. Emit (key jj , record (xj,uj,z)(xj,uj,z) ) to distributed database for j=1,,Nj=1,…,N

EndFunction

为了实现多次迭代,该算法需要由一个 wrapper 程序在每次迭代结束后判断是否满足迭代终止条件 ρNzzprev2ϵconvρN∥z−zprev∥2≤ϵconv 且 (Ni=1xiz22)1/2ϵfeas(∑i=1N∥xi−z∥22)1/2≤ϵfeas ,若不满足则启动下一次迭代。

参考文献

  • Boyd S, Parikh N, Chu E, et al. Distributed optimization and statistical learning via the alternating direction method of multipliers[J]. Foundations and Trends® in Machine Learning, 2011, 3(1): 1-122.

  • Eckstein J, Yao W. Understanding the convergence of the alternating direction method of multipliers: Theoretical and computational perspectives[J]. Pac. J. Optim., 2014.

  • Lusk E, Huss S, Saphir B, et al. MPI: A Message-Passing Interface Standard Version 3.1[J], 2015.

  • Dean J, Ghemawat S. MapReduce: simplified data processing on large clusters[J]. Communications of the ACM, 2008, 51(1): 107-113.

你可能感兴趣的:(Mathematics)