牛顿迭代(二元函数)

import numpy as np
import matplotlib.pyplot as plt
from sympy import *

# 定义符号
x1, x2 = symbols('x1, x2')
# 定义所求函数
f = 0.2*x1**2 + x2**2


# 求解梯度值
def get_grad(f, X):
    # 计算一阶导数
    f1 = diff(f, x1)
    f2 = diff(f, x2)
    X = X.tolist()
    X1 = X[0][0]
    X2 = X[1][0]
    # 代入具体数值计算
    grad = np.array([ [f1.subs({x1: X1, x2: X2}).evalf()],
                      [f2.subs({x1: X1, x2: X2}).evalf()] ])
    return grad


# 求解Hession矩阵
def get_hess(f, X):
    # 计算二次偏导
    f1 = diff(f, x1)
    f2 = diff(f, x2)
    f11 = diff(f,x1,2)
    f22 = diff(f,x2,2)
    f12 = diff(f1,x2)
    f21 = diff(f2,x1)
    # 计算具体数值计算
    hess = np.array([[f11.subs([(x1,X[0]), (x2,X[1])]),
                        f12.subs([(x1,X[0]), (x2,X[1])])],

                        [f21.subs([(x1,X[0]), (x2,X[1])]),
                        f22.subs([(x1,X[0]), (x2,X[1])])]])
    # 转换数值类型为了后续求逆矩阵
    hess = np.array(hess, dtype = 'float')
    return hess

# 牛顿迭代
def newton_iter(X, epsilon, max_iter):
    print('初始值','x1=',X[0][0],'x2=',X[1][0])
    count = 0
    while count< max_iter:
        grad = get_grad(f, X)
        grad_1_value = grad[0][0]
        grad_1_value = grad[1][0]
        if abs(grad_1_value) + abs(grad_1_value) >= epsilon:
            hess = get_hess(f, X0)
            # 得到Hession矩阵的逆
            hess_inv = np.linalg.inv(hess)
            # 牛顿迭代公式!!!!
            X = X - np.dot(hess_inv, grad)
            count += 1
            print('第',count,'次迭代:','x1=',X[0][0],'x2=',X[1][0])
        else:
            break
    print('迭代次数为:',count)
# 设置初始点
X0 = np.array([[1],[1]])
epsilon = 0.00001
max_iter = 50
newton_iter(X0, epsilon, max_iter)

def f(x, y):
    return 0.2*(x**2)+y**2

X = np.array([1, 0])
Y = np.array([1, 0])
x_values = np.linspace(-1, 1, 500)
y_values = np.linspace(-1, 1, 500)

xx, yy = np.meshgrid(x_values, y_values)

fig = plt.figure(figsize=(8, 8))
contour_line = plt.contour(xx, yy, f(xx, yy), levels=20, cmap=plt.cm.gray)
plt.clabel(contour_line, inline=1, fontsize=10)
plt.plot(X, Y, color='r', linestyle='-', marker='*', linewidth =2.0)
plt.show()

你可能感兴趣的:(python,python,numpy)