最速下降优化算法

最速下降优化算法

1、课后作业

最速下降优化算法_第1张图片

2、前情提要

最速下降优化算法_第2张图片
最速下降优化算法_第3张图片

最速下降优化算法_第4张图片
最速下降优化算法_第5张图片

手写求解

最速下降优化算法_第6张图片
最速下降优化算法_第7张图片
最速下降优化算法_第8张图片

代码求解

#最速下降优化算法
from sympy import*
import sympy
import math
from matplotlib import pyplot as plt
import numpy as np

def Obj(x1,x2):
    value=x1-x2+2*math.pow(x1,2)+2*x1*x2+math.pow(x2,2)
    return value
def Jac_x1(x1,x2):
    value=1+4*x1+2*x2
    return value
def Jac_x2(x1,x2):
    value=-1+2*x1+2*x2
    return value

def Error(x1,x2):
    value=math.sqrt(math.pow(x1,2)+math.pow(x2,2))
    return value


#指定初始值
Epsilon=0.01
#设定误差的初始值
X=[0,0]#x0
Direction=[-Jac_x1(X[0],X[1]),-Jac_x2(X[0],X[1])]#指定负梯度方向作为下降方向
t=sympy.symbols('t') # t是我们的步长
Temp_x=[X[0]+t*Direction[0],X[1]+t*Direction[1]]
z=diff(Temp_x[0]-Temp_x[1]+2*Temp_x[0]*Temp_x[0]+2*Temp_x[0]*Temp_x[1]+Temp_x[1]*Temp_x[1])
Lambda=sympy.solve(z) #求解最优步长,该步长使我们的函数下降最快
Current_error=Error(Jac_x1(X[0],X[1]),Jac_x2(X[0],X[1]))

#算法迭代部分
Value=[]
Minimum=Obj(X[0],X[1])
error=[]
lamb=[]
while Current_error > Epsilon:
    print("当前错误值:{}".format(Current_error))
    #保存数据
    error.append(Current_error)
    Value.append(Obj(X[0],X[1]))
    lamb.append(Lambda)
    # if Minimum>Obj(X[0],X[1]):
    #     Minimum=Obj(X[0],X[1])
    #     if(Minimum==-1.25):
    #         break
    #print(Minimum)
    #更新相关数据 
    temp1=X[0]+Lambda[0]*Direction[0]
    temp2=X[1]+Lambda[0]*Direction[1]
    X[0]=temp1
    X[1]=temp2
    #X=[X[0]+Lambda*Direction[0],X[1]+Lambda*Direction[1]]#算法核心迭代部分,更新点
    Direction=[-Jac_x1(X[0],X[1]),-Jac_x2(X[0],X[1])] #更新方向向量
    Temp_x=[X[0]+t*Direction[0],X[1]+t*Direction[1]]
    z=diff(Temp_x[0]-Temp_x[1]+2*Temp_x[0]*Temp_x[0]+2*Temp_x[0]*Temp_x[1]+Temp_x[1]*Temp_x[1])
    Current_error=Error(Jac_x1(X[0],X[1]),Jac_x2(X[0],X[1]))
    Lambda=sympy.solve(z) #更新Lambda
print("最小值是:{}".format(Obj(X[0],X[1])))
print("迭代次数是:{}".format(len(Value)))
print("最小值点是:",X)

# plt.plot(Value)
# plt.title("x1 x2坐标")
# plt.xlabel("The number of iteration")
# plt.ylabel("The value of function")
# fig=plt.figure()
# plt.plot(error)
# plt.title("The changes of error")
# plt.xlabel("The number of iteration")
# plt.ylabel("The value of error")

# plt.show()

结果显示

当前错误值:1.4142135623730951
当前错误值:1.4142135623730951
当前错误值:0.28284271247461906
当前错误值:0.28284271247461906
当前错误值:0.0565685424949238
当前错误值:0.0565685424949238
当前错误值:0.01131370849898476
当前错误值:0.01131370849898476
最小值是:-1.24999680000000
迭代次数是:8
最小值点是: [-624/625, 936/625]

与手算结果一致

你可能感兴趣的:(算法,vscode,matplotlib,python)