拟牛顿法是在牛顿法的基础上引入了Hessian矩阵的近似矩阵,避免每次迭代都计算Hessian矩阵的逆,它的收敛速度介于梯度下降法和牛顿法之间。拟牛顿法跟牛顿法一样,也是不能处理太大规模的数据,因为计算量和存储空间会开销很多。
拟牛顿法虽然每次迭代不像牛顿法那样保证是最优化的方向,但是近似矩阵始终是正定的,因此算法始终是朝着最优化的方向在搜索。具有全局收敛性和超线性收敛速度
BFGS(Broyden,Fletcher,Goldfarb,Shanno四个人)算法是使用较多的一种拟牛顿方法,故称为BFGS校正。
将 x 写成 x=(x1,x2,…,xn) 。对函数 f(x) 在 x=xk+1 处进行泰勒展开到二阶:
令 Ek=αukuTk+βvkvTk ,其中 uk,vk 均为 n∗1 的向量。 yk=∇f(xk+1)−∇f(xk),sk=xk+1−xk .
那么 Bk+1(xk−xk+1)=∇f(xk+1)−∇f(xk)
可以化简为:
即:
uTksk,vTksk 皆为实数, yk−Bksk 为 n∗1 的向量,上式中,参数 α 和 β 解的可能性有很多,我们取特殊的情况,假设 uk=rBksk,vk=θyk 。则:
设 Bk 对称正定, Bk+1 由上述的BFGS校正公式确定,那么 Bk+1 对称正定的充要条件是 yTksk>0 。
非精确的一维搜索(线搜索)准则:Armijo搜索准则,搜索准则的目的是为了帮助我们确定学习率,还有其他的一些准则,如Wolfe准则以及精确线搜索等。在利用Armijo搜索准则时并不是都满足上述的充要条件,此时可以对BFGS校正公式做些许改变:
注:在李航写的那本《统计学习方法》中说是正定的,但是并没有说上述情况下会怎么样
求解无约束优化问题:
#coding:UTF-8
'''
Created on 2017年4月20日
@author: zhangdapeng
'''
from numpy import *
import matplotlib.pyplot as plt
from numpy.matrixlib.defmatrix import mat
#fun 原始函数
def fun(x):
return 100 * (x[0,0] ** 2 - x[1,0]) ** 2 + (x[0,0] - 1) ** 2
#对x1,x2求导后的函数
def gfun(x):
result = zeros((2, 1))
# 对x1求导
result[0, 0] = 400 * x[0,0] * (x[0,0] ** 2 - x[1,0]) + 2 * (x[0,0] - 1)
result[1, 0] = -200 * (x[0,0] ** 2 - x[1,0]) #对x2求导
return result
def bfgs(fun, gfun, x0):
result = []
maxk = 500
delta = 0.55
sigma = 0.4
m = shape(x0)[0]
Bk = eye(m)
k = 0
epsilon=1e-10
while (k < maxk):
gk = mat(gfun(x0))#计算梯度 ,mat函数将数组转化为矩阵。
# print(gk)
# print(linalg.norm(gk,1))
#axis=0,沿着纵轴方向
if linalg.norm(gk,1)break
dk = mat(-linalg.solve(Bk, gk)) #解矩阵方程Bk*x=gk得到x
m = 0
mk = 0
while (m < 20):
newf = fun(x0 + delta ** m * dk)
oldf = fun(x0)
if (newf < oldf + sigma * (delta ** m) * (gk.T * dk)[0,0]):
mk = m
break
m = m + 1
#BFGS校正
x = x0 + delta ** mk * dk
sk = x - x0
yk = gfun(x) - gk
# print(math.isnan(yk.T * sk))
if (yk.T * sk > 0):
Bk = Bk - (Bk * sk * sk.T * Bk) / (sk.T * Bk * sk) + (yk * yk.T) / (yk.T * sk)
k = k + 1
x0 = x
result.append(fun(x0))
return result
#初始化x0
x0 = mat([[-1.2], [1]])
result = bfgs(fun, gfun, x0)
print("result:",result[-1])
n = len(result)
ax = plt.figure().add_subplot(111)
x = arange(0, n, 1)
y = result
ax.plot(x,y)
plt.show()
result: 2.68262011582e-28
http://blog.csdn.net/google19890102/article/details/45867789
http://blog.csdn.net/acdreamers/article/details/44664941
http://www.codelast.com/%E5%8E%9F%E5%88%9B%E7%94%A8%E4%BA%BA%E8%AF%9D%E8%A7%A3%E9%87%8A%E4%B8%8D%E7%B2%BE%E7%A1%AE%E7%BA%BF%E6%90%9C%E7%B4%A2%E4%B8%AD%E7%9A%84armijo-goldstein%E5%87%86%E5%88%99%E5%8F%8Awo/