import numpy as np


def f(x1, x2):
    """Objective with two features: f(x1, x2) = 0.6*(x1 + x2)**2 - x1*x2."""
    return 0.6 * (x1 + x2) ** 2 - x1 * x2


def hx1(x1, x2):
    """Partial derivative of f with respect to x1."""
    return 1.2 * (x1 + x2) - x2


def hx2(x1, x2):
    """Partial derivative of f with respect to x2."""
    return 1.2 * (x1 + x2) - x1


def gradient_descent(x1=7.0, x2=6.0, alpha=0.5, max_iter=500, tol=1e-10):
    """Minimize f by gradient descent with a simultaneous update of (x1, x2).

    Parameters
    ----------
    x1, x2 : float
        Starting point (defaults reproduce the original script: (7, 6)).
    alpha : float
        Learning rate.
    max_iter : int
        Iteration cap.
    tol : float
        Stop when the absolute change in f between iterations drops below this.

    Returns
    -------
    tuple
        (x1, x2, f_value, iter_num, path_x1, path_x2, path_y) where the
        path_* lists record every visited point, including the start.
    """
    f_current = f(x1, x2)
    f_change = f_current
    path_x1, path_x2, path_y = [x1], [x2], [f_current]
    iter_num = 0
    while iter_num < max_iter and f_change > tol:
        iter_num += 1
        # Evaluate BOTH partials at the previous point before moving either
        # coordinate (simultaneous update, not coordinate descent).
        pre_x1, pre_x2 = x1, x2
        x1 = pre_x1 - alpha * hx1(pre_x1, pre_x2)
        x2 = pre_x2 - alpha * hx2(pre_x1, pre_x2)
        tmp = f(x1, x2)
        # Absolute improvement in the objective drives the stop criterion.
        f_change = abs(f_current - tmp)
        f_current = tmp
        path_x1.append(x1)
        path_x2.append(x2)
        path_y.append(f_current)
    return x1, x2, f_current, iter_num, path_x1, path_x2, path_y


def _plot_descent(x1, x2, f_current, iter_num, path_x1, path_x2, path_y, alpha):
    """Draw the surface of f and the descent path on a 3D axis."""
    # Imported lazily so the numeric part of this module is usable in a
    # headless environment without matplotlib installed.
    import matplotlib.pyplot as plt

    # Render Chinese labels without glyph warnings; keep the minus sign ASCII.
    plt.rcParams['font.sans-serif'] = [u'simHei']
    plt.rcParams['axes.unicode_minus'] = False

    # Grid of sample points for the objective surface.
    X1 = np.arange(-8, 8.5, 0.1)
    X2 = np.arange(-8, 8.5, 0.1)
    X1, X2 = np.meshgrid(X1, X2)
    # f is a polynomial, so numpy broadcasting evaluates it over the whole
    # grid at once — no need for the map/lambda over flattened arrays.
    Y = f(X1, X2)

    fig = plt.figure(facecolor='w')
    # Axes3D(fig) no longer attaches the axes to the figure on modern
    # matplotlib (>= 3.4); add_subplot(projection='3d') is the supported way.
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(X1, X2, Y, rstride=1, cstride=1, cmap=plt.cm.jet)
    ax.plot(path_x1, path_x2, path_y, 'ko--')
    ax.set_title(u'函数$0.6 * (x1 + x2) ^ 2 - x1 * x2$: \n学习率:%.3f; 最终解是:(%.3f, %.3f, %.3f); 迭代次数:%d'
                 % (alpha, x1, x2, f_current, iter_num))
    plt.show()


def main():
    """Reproduce the original script: descend from (7, 6), print, then plot."""
    alpha = 0.5
    x1, x2, f_current, iter_num, path_x1, path_x2, path_y = gradient_descent(
        x1=7, x2=6, alpha=alpha)
    print(u'最终结果为:(%.5f, %.5f, %.5f)' % (x1, x2, f_current))
    print(u'最终迭代次数结果为:%d' % iter_num)
    print(path_y)
    _plot_descent(x1, x2, f_current, iter_num, path_x1, path_x2, path_y, alpha)


if __name__ == '__main__':
    main()
# Sample run (pasted console output, kept as a comment so the file stays valid Python):
# E:\myprogram\anaconda\python.exe E:/xx/机器学习/梯度下降操作/梯度下降算法3维图像示例.py
# 最终结果为:(0.00000, -0.00000, 0.00000)
# 最终迭代次数结果为:17
# [59.39999999999999, 5.386000000000001, 0.4947399999999999, 0.04702659999999999, 0.004857394, 0.0005934154600000001, 9.246989140000004e-05, 1.8087915226000006e-05, 4.06931862034e-06, 9.765902383306002e-07, 2.40481012074754e-07, 5.979026374297786e-08, 1.491786690093051e-08, 3.726793812099371e-09, 9.314578908428496e-10, 2.3284282211433304e-10, 5.820875697490912e-11, 1.4552013873896607e-11]