交替方向乘子法(Alternating Direction Method of Multipliers,ADMM)是一种解决可分解凸优化问题的简单方法,尤其在解决大规模问题上卓有成效,利用ADMM算法可以将原问题的目标函数等价的分解成若干个可求解的子问题,然后并行求解每一个子问题,最后协调子问题的解得到原问题的全局解。ADMM 最早分别由 Glowinski & Marrocco 及 Gabay & Mercier 于 1975 年和 1976 年提出,并被 Boyd 等人于 2011 年重新综述并证明其适用于大规模分布式优化问题。由于 ADMM 的提出早于大规模分布式计算系统和大规模优化问题的出现,所以在 2011 年以前,这种方法并不广为人知。
考虑等式约束的最优化问题如下 min x f ( x ) \min_x f(x) xminf(x) s t . A x = b st. Ax=b st.Ax=b其中 x ∈ R n , A ∈ R m × n , f : R n ∈ R x\in\mathbb R^n,A\in\mathbb R^{m\times n},f:\mathbb R^n \in\mathbb R x∈Rn,A∈Rm×n,f:Rn∈R是凸函数
原问题的拉格朗日函数为: L ( x , y ) = f ( x ) + y T ( A x − b ) L(x,y) = f(x) + y^T(Ax-b) L(x,y)=f(x)+yT(Ax−b)那么其对偶函数为: g ( y ) = i n f x L ( x , y ) = − f ∗ ( − A T y ) − b T y g(y)=inf_xL(x,y)=-f^*(-A^Ty)-b^Ty g(y)=infxL(x,y)=−f∗(−ATy)−bTy其中 y y y是拉格朗日乘子,也是对偶变量, f ∗ f^* f∗是 f f f共轭函数。
假设满足强对偶性,则原问题和对偶问题的最优值相等。我们设原问题最优解为 x ∗ x^* x∗,对偶问题最优解为 y ∗ y^* y∗,则
x ∗ = a r g m i n x L ( x , y ∗ ) x^*=argmin_xL(x,y^*) x∗=argminxL(x,y∗)在对偶上升方法中,对偶问题是通过梯度上升方法来解,因此对偶上升迭代更新为: x k + 1 = a r g m i n x L ( x , y k ) x_{k+1}=argmin_xL(x,y^k) xk+1=argminxL(x,yk) y k + 1 = y k + α k ( A x k + 1 − b ) y^{k+1}=y^k+\alpha_k(Ax^k+1-b) yk+1=yk+αk(Axk+1−b)其中 α k > 0 \alpha_k>0 αk>0,是梯度上升的步长。
对偶上升方法中在满足强对偶性条件下,通过梯度上升来逐步调整对偶变量,再通过对偶变量来求解原问题最优解,这样的好处是在有些情况下可以使算法可分解,假设目标函数是可分解的,即,
f ( x ) = ∑ i = 1 N f i ( x i ) f(x)=\sum_{i=1}^Nf_i(x_i) f(x)=i=1∑Nfi(xi)其中 x = ( x 1 , x 2 , … , x N ) , x i ∈ R n i x=(x_1,x_2,\dots,x_N),x_i\in\mathbb R^{n_i} x=(x1,x2,…,xN),xi∈Rni,划分矩阵A A = [ A 1 , A 2 , ⋯   , A N ] A=[A_1,A_2,\cdots,A_N] A=[A1,A2,⋯,AN]所以 A x = ∑ i = 1 N A i x i Ax=\sum_{i=1}^NA_ix_i Ax=∑i=1NAixi,则拉格函数重写成 L ( x , y ) = ∑ i = 1 N L i ( x i , y ) = ∑ i = 1 N ( f i ( x i ) + y T A i x i − ( 1 N ) y T b ) L(x,y)=\sum_{i=1}^NL_i(x_i,y)=\sum_{i=1}^N(f_i(x_i)+y^TA_ix_i-(\frac{1}{ N})y^Tb) L(x,y)=i=1∑NLi(xi,y)=i=1∑N(fi(xi)+yTAixi−(N1)yTb)对偶上升的迭代更新: x i k + 1 = a r g m i n x i L i ( x i , y k ) x_i^{k+1}=argmin_{x_i}L_i(x_i,y^k) xik+1=argminxiLi(xi,yk) y k + 1 = y k + α k ( A x k + 1 − b ) y^{k+1}=y^k+\alpha^k(Ax^{k+1}-b) yk+1=yk+αk(Axk+1−b)
为了增加对偶上升方法的鲁棒性和放松目标强凸的要求,引入了增广拉格朗日,增广拉格朗日,形式如下: L ρ ( x , y ) = f ( x ) + y T ( A x − b ) + ρ 2 ∥ A x − b ∥ 2 2 L_\rho(x,y)=f(x)+y^T(Ax-b)+\frac{\rho}{2}\begin{Vmatrix} Ax-b \end{Vmatrix}_2^2 Lρ(x,y)=f(x)+yT(Ax−b)+2ρ∥∥Ax−b∥∥22其中 ρ > 0 \rho >0 ρ>0,是惩罚参数。
增广拉格朗日相当于在拉格朗日加了一个强凸的惩罚项使得原约束问题变成如下形式 min x f ( x ) + ρ 2 ∥ A x − b ∥ 2 2 \min_xf(x)+\frac{\rho}{2}\begin{Vmatrix} Ax-b\end{Vmatrix}_2^2 xminf(x)+2ρ∥∥Ax−b∥∥22 s . t . A x = b s.t. Ax=b s.t.Ax=b对偶上升迭代更新: x k + 1 = a r g m i n x L ρ ( x , y k ) x^{k+1}=argmin_xL_\rho(x,y^k) xk+1=argminxLρ(x,yk) y k + 1 = y k + ρ ( A x k + 1 − b ) y^{k+1}=y^k+\rho(Ax^{k+1}-b) yk+1=yk+ρ(Axk+1−b)这就是所谓的乘子法,从迭代形式上看,有两点和对偶上升方法不一样,第一是 x k + 1 x^{k+1} xk+1更新使用增广拉格朗日,第二是更新的步长用惩罚参数 ρ \rho ρ代替了 α k \alpha^k αk。乘子法虽然可以比对偶上升方法在更一般的条件下收敛,但是增加了二次惩罚项,使得原问题失去可分解性,因而引入了ADMM算法。
假设有如下优化问题: min x f ( x ) + g ( z ) \min_xf(x)+g(z) xminf(x)+g(z) s . t . A x + B z = c s.t. Ax+Bz=c s.t.Ax+Bz=c其中 x ∈ R n , z ∈ R m , A ∈ R p × n , B ∈ R p × m , c ∈ R p x\in\mathbb R^n,z\in\mathbb R^m,A\in\mathbb R^{p\times n},B\in\mathbb R^{p\times m},c\in\mathbb R^p x∈Rn,z∈Rm,A∈Rp×n,B∈Rp×m,c∈Rp
其增广拉格朗日函数如下: L ρ ( x , z , y ) = f ( x ) + g ( z ) + y T ( A x + B z − c ) + ρ 2 ∥ A x + B z − c ∥ 2 2 L_\rho(x,z,y)=f(x)+g(z)+y^T(Ax+Bz-c)+\frac{\rho}{2}\begin{Vmatrix} Ax+Bz-c\end{Vmatrix}_2^2 Lρ(x,z,y)=f(x)+g(z)+yT(Ax+Bz−c)+2ρ∥∥Ax+Bz−c∥∥22ADMM的更新迭代形式如下: x k + 1 = a r g m i n x L ρ ( x , y k , z k ) x^{k+1}=argmin_xL_\rho(x,y^k,z^k) xk+1=argminxLρ(x,yk,zk) z k + 1 = a r g m i n z L ρ ( x k + 1 , z , y k ) z^{k+1}=argmin_zL_\rho(x^{k+1},z,y^k) zk+1=argminzLρ(xk+1,z,yk) y k + 1 = y k + ρ ( A x k + 1 + B z k + 1 − c ) y^{k+1}=y^k+\rho(Ax^{k+1}+Bz^{k+1}-c) yk+1=yk+ρ(Axk+1+Bzk+1−c)其中 ρ > 0 \rho > 0 ρ>0,以上就是最原始的ADMM算法迭代形式
接下来讨论如何将其应用到机器学习算法分布式计算中,但在这之前我们先讨论一个比较广义的优化问题如下: min x ∑ i = 1 N f i ( x i ) \min_x\sum_{i=1}^Nf_i(x_i) xmini=1∑Nfi(xi) s . t . x i − z = 0 , i = 0 , 1 , … , N s.t.x_i-z=0,i=0,1,\dots,N s.t.xi−z=0,i=0,1,…,N以上问题形式就是全局一致性问题, x i x_i xi局部变量, z z z是全局一致变量。从限制约束条件来看,每一个局部变量需要保持一致。
其增广拉格朗日函数为: L ρ ( x 1 , x 2 , … , x N , z , y ) = ∑ i = 1 N f i ( x i ) + y i T ( x i − z ) + ρ 2 ∥ x i − z ∥ 2 2 L_\rho(x_1,x_2,\dots,x_N,z,y)=\sum_{i=1}^Nf_i(x_i)+y^T_i(x_i-z)+\frac{\rho}{2}\begin{Vmatrix} x_i-z\end{Vmatrix}_2^2 Lρ(x1,x2,…,xN,z,y)=i=1∑Nfi(xi)+yiT(xi−z)+2ρ∥∥xi−z∥∥22ADMM迭代更新形式: x i k + 1 = a r g m i n x f i ( x i ) + y i T k ( x i − z k ) + ρ 2 ∥ x i − z k ∥ 2 2 x_i^{k+1}=argmin_xf_i(x_i)+y_i^{T^k}(x_i-z^k)+\frac{\rho}{2}\begin{Vmatrix} x_i-z^k\end{Vmatrix}_2^2 xik+1=argminxfi(xi)+yiTk(xi−zk)+2ρ∥∥xi−zk∥∥22 z k + 1 = 1 N ∑ i = 1 N ( x i k + 1 + 1 ρ y i k ) z^{k+1}=\frac{1}{N}\sum_{i=1}^N(x_i^{k+1}+\frac{1}{\rho}y_i^k) zk+1=N1i=1∑N(xik+1+ρ1yik) y i k + 1 = y i k + ρ ( x i k + 1 − z k + 1 ) y_i^{k+1}=y_i^k+\rho(x_i^{k+1}-z^{k+1}) yik+1=yik+ρ(xik+1−zk+1)
我们可以更进一步将上述问题推广更广义的优化问题,带正则的全局一致性优化问题,具体问题如下: min x ∑ i = 1 N f i ( x i ) + g ( z ) \min_x\sum_{i=1}^Nf_i(x_i)+g(z) xmini=1∑Nfi(xi)+g(z) s . t . x i − z = 0 , i = 0 , 1 , … , N s.t. x_i-z=0,i=0,1,\dots,N s.t.xi−z=0,i=0,1,…,N其ADMM迭代更新形式如下: x i k + 1 = a r g m i n x ( f i ( x i ) + y i T k ( x i − z k ) + ρ 2 ∥ x i − z k ∥ 2 2 ) x_i^{k+1}=argmin_x(f_i(x_i)+y_i^{T^k}(x_i-z^k)+\frac{\rho}{2}\begin{Vmatrix} x_i-z^k\end{Vmatrix}_2^2) xik+1=argminx(fi(xi)+yiTk(xi−zk)+2ρ∥∥xi−zk∥∥22) z k + 1 = a r g m i n z ( g ( z ) + ∑ i = 1 N ( − y i T k + ρ 2 ∥ x i k + 1 − z ∥ 2 2 ) ) z^{k+1}=argmin_z(g(z)+\sum_{i=1}^N(-y_i^{T^k}+\frac{\rho}{2}\begin{Vmatrix} x_i^{k+1}-z\end{Vmatrix}_2^2)) zk+1=argminz(g(z)+i=1∑N(−yiTk+2ρ∥∥xik+1−z∥∥22)) y i k + 1 = y i k + ρ ( x i k + 1 − z i k + 1 ) y_i^{k+1}=y_i^k+\rho(x_i^{k+1}-z_i^{k+1}) yik+1=yik+ρ(xik+1−zik+1)
##ADMM算法的应用
以上讨论了一些的理论的东西,接下来就尝试一下如何去应用ADMM算法去解决一些机器学习问题。实际上大多数分布式机器学习问题可以表达成全局一致性优化问题: min x ∑ i = 1 N f i ( x i ) + g ( z ) \min_x\sum_{i=1}^Nf_i(x_i)+g(z) xmini=1∑Nfi(xi)+g(z) s . t . x i − z = 0 , i = 0 , 1 , … , N , s.t.x_i-z=0,i=0,1,\dots,N, s.t.xi−z=0,i=0,1,…,N, f i ( x i ) f_i(x_i) fi(xi)表示划分到每一个节点上的目函数, x i x_i xi表示局部模型参数, z z z表示全局一致性变量。我们把整个数据集划分到每一节点,各节点独立并行训练,通过迭代更新,最终收敛到一个一致的全局解。
##Lasso问题
具体问题形式如下: min x 1 2 ∥ A x − b ∥ 2 2 + λ ∥ x ∥ 1 \min_x\frac{1}{2}\begin{Vmatrix} Ax-b\end{Vmatrix}_2^2+\lambda\begin{Vmatrix} x\end{Vmatrix}_1 xmin21∥∥Ax−b∥∥22+λ∥∥x∥∥1其中, A ∈ R m × n , b ∈ R m A\in\mathbb R^{m\times n},b\in\mathbb R^m A∈Rm×n,b∈Rm.
更换成全局一致性优化问题: min x i ∑ i = 1 N 1 2 ∥ A x − b ∥ 2 + λ ∥ z ∥ 1 \min_{x_i}\sum_{i=1}^N\frac{1}{2}\begin{Vmatrix} Ax-b\end{Vmatrix}^2+\lambda\begin{Vmatrix} z\end{Vmatrix}_1 ximini=1∑N21∥∥Ax−b∥∥2+λ∥∥z∥∥1 s . t . x i − z = 0 , i = 0 , 1 , … , N s.t. x_i-z=0,i=0,1,\dots,N s.t.xi−z=0,i=0,1,…,NADMM迭代更新形式如下: x i k + 1 = a r g m i n x i ( 1 2 ∥ A i x i − b i ∥ 2 + y i T k ( x i − z k ) + ρ 2 ∥ x i − z k ∥ 2 2 ) x_i^{k+1}=argmin_{x_i}(\frac{1}{2}\begin{Vmatrix} A_ix_i-b_i\end{Vmatrix}^2+y_i^{T^k}(x_i-z^k)+\frac{\rho}{2}\begin{Vmatrix} x_i-z^k\end{Vmatrix}_2^2) xik+1=argminxi(21∥∥Aixi−bi∥∥2+yiTk(xi−zk)+2ρ∥∥xi−zk∥∥22) z k + 1 = a r g m i n z ( λ ∥ z ∥ 1 + ∑ i = 1 N ( − y i T k z + ρ 2 ∥ x i k + 1 − z ∥ 2 2 ) ) z^{k+1}=argmin_z(\lambda\begin{Vmatrix} z\end{Vmatrix}_1+\sum_{i=1}^N(-y_i^{T^k}z+\frac{\rho}{2}\begin{Vmatrix} x_i^{k+1}-z\end{Vmatrix}_2^2)) zk+1=argminz(λ∥∥z∥∥1+i=1∑N(−yiTkz+2ρ∥∥xik+1−z∥∥22)) y i k + 1 = y i k + ρ ( x i k + 1 − z k + 1 ) y_i^{k+1}=y_i^k+\rho(x_i^{k+1}-z^{k+1}) yik+1=yik+ρ(xik+1−zk+1)
对于 x k + 1 x^{k+1} xk+1的更新,右边括号内的函数对 x x x求可导得(这里可以有更高效的算法求解): A i T ( A i x i − b i ) + y i k + ρ ( x i − z k ) = 0 A_i^T(A_ix_i-b_i)+y_i^k+\rho(x_i-z^k)=0 AiT(Aixi−bi)+yik+ρ(xi−zk)=0即可得: x i k + 1 = ( A i T A i + ρ I ) − 1 ( A i T b i + ρ z k − y i k ) x_i^{k+1}=(A_i^TA_i+\rho I)^{-1}(A_i^Tb_i+\rho z^k-y_i^k) xik+1=(AiTAi+ρI)−1(AiTbi+ρzk−yik)对于 z z z的更新如下: z k + 1 = a r g m i n z ( λ ∥ z ∥ 1 + ρ N 2 ∥ z − 1 ρ N ∑ i = 1 N ( ρ x i k + 1 + y i k ) ∥ 2 2 ) z^{k+1}=argmin_z(\lambda\begin{Vmatrix} z\end{Vmatrix}_1+\frac{\rho N}{2}\begin{Vmatrix} z-\frac{1}{\rho N}\sum_{i=1}^N(\rho x_i^{k+1}+y_i^k)\end{Vmatrix}_2^2) zk+1=argminz(λ∥∥z∥∥1+2ρN∥∥∥z−ρN1∑i=1N(ρxik+1+yik)∥∥∥22)因此 z z z的更新可以使用softthreshold方法如下所示: z k + 1 = S λ / ρ N ( 1 ρ N ∑ i = 1 N ( ρ x i k + 1 + y i k ) ) z^{k+1}=\mathcal S_{\lambda/\rho N}(\frac{1}{\rho N}\sum_{i=1}^N(\rho x_i^{k+1}+y_i^k)) zk+1=Sλ/ρN(ρN1i=1∑N(ρxik+1+yik))其中 S κ ( a ) = { a − κ , a > κ 0 , ∣ a ∣ ≤ κ a + κ , a < − κ \mathcal S_{\kappa}(a)= \begin{cases} a-\kappa, & \text { $a>\kappa$ } \\ 0, & \text{ $\mid a\mid\leq\kappa$} \\a+\kappa , & \text{$a<-\kappa$}\end{cases} Sκ(a)=⎩⎪⎨⎪⎧a−κ,0,a+κ, a>κ ∣a∣≤κa<−κ