
import numpy as np
def generate_noise_data(num=1000, noise_sigma=1.0, coeffs=(1, 2, 3), seed=None):
    """Sample a noisy exponential-of-quadratic curve on the interval [0, 1].

    The noise-free model is ``y = exp(a * x**2 + b * x + c)``. Zero-mean
    Gaussian noise with standard deviation ``noise_sigma`` is added to every
    sample.

    Args:
        num: Number of evenly spaced sample points between 0 and 1 inclusive.
        noise_sigma: Standard deviation of the additive Gaussian noise.
        coeffs: Ground-truth ``(a, b, c)`` of the model. Defaults to
            ``(1, 2, 3)``.
        seed: Optional seed for reproducible noise. When ``None`` the noise
            is drawn from NumPy's global random state, so code that seeds
            ``np.random`` itself keeps working; otherwise a dedicated
            generator is created and the global state is left untouched.

    Returns:
        A tuple ``(x, y)`` of 1-D arrays, each of length ``num``.
    """
    a, b, c = coeffs
    x = np.linspace(0, 1, num)
    clean = np.exp(a * x**2 + b * x + c)
    if seed is None:
        noise = np.random.normal(0, noise_sigma, num)
    else:
        noise = np.random.default_rng(seed).normal(0, noise_sigma, num)
    return x, clean + noise
def gauss_newton(x=None, y=None, init=(1.0, 1.0, 1.0), max_iter=101,
                 tol=None, verbose=True):
    """Fit ``y ~= exp(a * x**2 + b * x + c)`` with the Gauss-Newton method.

    Each iteration linearises the residual ``r = y - f(x; a, b, c)`` around
    the current estimate and solves the normal equations
    ``(J J^T) delta = -J r`` for the parameter update, where the rows of
    ``J`` are the partial derivatives of ``r`` with respect to a, b and c.

    Args:
        x: 1-D sample locations. When both ``x`` and ``y`` are ``None`` the
            data comes from ``generate_noise_data()``.
        y: Observations, same shape as ``x``.
        init: Starting guess ``(a, b, c)``.
        max_iter: Number of Gauss-Newton iterations to run at most.
        tol: Optional early-stop threshold on the largest absolute parameter
            update. ``None`` (the default) always runs ``max_iter`` steps.
        verbose: When true, print progress every 10th iteration.

    Returns:
        The fitted ``(a, b, c)`` as a tuple of floats.

    Raises:
        ValueError: If only one of ``x``/``y`` is given or their shapes
            differ.
        numpy.linalg.LinAlgError: If the normal matrix is singular.
    """
    if x is None and y is None:
        x, y = generate_noise_data()
    elif x is None or y is None:
        raise ValueError("x and y must be given together")
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")

    a, b, c = (float(value) for value in init)

    # Derivatives of the exponent w.r.t. (a, b, c); constant across iterations.
    exponent_grad = np.vstack([x**2, x, np.ones_like(x)])

    for step in range(max_iter):
        model = np.exp(a * x**2 + b * x + c)
        residual = y - model
        cost = np.sum(residual ** 2)

        # d(residual)/d(param) = -d(exponent)/d(param) * model, shape (3, n).
        jacobian = -exponent_grad * model
        normal_matrix = jacobian @ jacobian.T
        rhs = -(jacobian @ residual)

        # Solve the 3x3 system directly instead of forming an explicit inverse.
        delta = np.linalg.solve(normal_matrix, rhs)
        a += delta[0]
        b += delta[1]
        c += delta[2]

        # NOTE: `cost` belongs to the parameters *before* this update, while
        # the printed a, b, c are the values *after* it.
        if verbose and step % 10 == 0:
            print(f"iter: {step}, cost: {cost}, a: {a}, b: {b}, c: {c}")

        if tol is not None and np.max(np.abs(delta)) < tol:
            break

    return float(a), float(b), float(c)
if __name__ == "__main__":
    # Run the demo fit only when executed as a script, not when imported.
    gauss_newton()
'''
iter: 0, cost: 19451883.170121048, a: 7.778319737506763, b: 6.646686710088488, c: 7.615218392754173
iter: 10, cost: 4952005071479.72, a: 8.788212680472725, b: 4.652260498746273, c: -1.3984130105756094
iter: 20, cost: 113097.87399971182, a: 1.2484731861459692, b: 1.6092037161265653, c: 3.1479602802393742
iter: 30, cost: 930.2174621037896, a: 1.0005171031754931, b: 1.9993132909257523, c: 3.0000523922709754
iter: 40, cost: 930.2174621037897, a: 1.0005171031754931, b: 1.9993132909257525, c: 3.0000523922709754
iter: 50, cost: 930.2174621037914, a: 1.0005171031754936, b: 1.9993132909257518, c: 3.0000523922709754
iter: 60, cost: 930.2174621037896, a: 1.0005171031754931, b: 1.9993132909257523, c: 3.0000523922709754
iter: 70, cost: 930.2174621037897, a: 1.0005171031754931, b: 1.9993132909257525, c: 3.0000523922709754
iter: 80, cost: 930.2174621037914, a: 1.0005171031754936, b: 1.9993132909257518, c: 3.0000523922709754
iter: 90, cost: 930.2174621037896, a: 1.0005171031754931, b: 1.9993132909257523, c: 3.0000523922709754
iter: 100, cost: 930.2174621037897, a: 1.0005171031754931, b: 1.9993132909257525, c: 3.0000523922709754
'''