下降单纯形法(downhill simplex method)是一个广泛使用的“derivative free”的优化算法。一般来说它的效率不高,但是文献[1]提到“the downhill simplex method may frequently be the *best* method to use if the figure of merit is “get something working quickly” for a problem whose computational burden is small.”
- 首先计算最差点\(\mathbf{x}_{n+1}\)沿着直线\(\bar{\mathbf{x}}(t)=(1-t)\bar{\mathbf{x}}+t\,\mathbf{x}_{n+1}\)关于平均点\(\bar{\mathbf{x}}\)的对称点,即令\(t=-1\);
- 如果对称点介于最好和次差点之间,那么就接受它(reflection);
- 如果对称点比最好点还好,那么做沿该方向更大胆的尝试,令\(t=-2\),如果新尝试点比对称点更好则接受新尝试点(expand),否则接受当前对称点(reflection);
- 如果对称点介于次差点和最差点之间,那么沿该方向做更小心的尝试,即令\(t=-0.5\),如果新尝试点比对称点更好则接受新尝试点(outside contraction)
- 如果对称点比最差点还差,那么沿反方向做尝试,即令\(t=0.5\),如果新尝试点比对称点更好则接受新尝试点(inside contraction)
- 如果上面第4步或第5步的尝试失败,即对称点比次差点还要差,而且相应的outside contraction或inside contraction没有得到更好的点,那么把最好点之外的其他点都朝最好点收缩(shrink)。
上述过程如果用区间图表示会更清晰,区间的三个分界点就是最好、 次差和最差点。对应的伪代码可以参考文献 [3] 第9.5节。
另外,在实际应用中还有一个很重要的诀窍,即重启(restart)。因为单纯形在迭代更新的时候很容易卡在某个中间位置上,这时单纯形的 *最好* 点和 *最差* 点几乎相同,单纯形的体积收缩得很小,会大大减慢迭代速度。为了解决这个问题,可以合理设置初始单纯形的大小。更有效的做法是在单纯形卡住的时候重新初始化单纯形来加快收敛速度。在利用初始化公式\(\mathbf{x}_{i}=\mathbf{x}_{0}+\lambda\mathbf{e}_{i}\)(\(\lambda\)为步长,\(\mathbf{e}_{i}\)为第\(i\)个单位向量)的时候,把当前单纯形的 *最好* 点作为\(\mathbf{x}_{0}\)保留下来,这样就能保证重启不会影响之前已经计算的结果。
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns


def vertice_init(vertex_0, step_length):
    """Build the initial simplex around ``vertex_0``.

    For an ``n``-dimensional start point the simplex has ``n + 1`` vertices:
    ``x_0 = vertex_0`` and ``x_i = x_0 + step_length * e_i``, where ``e_i``
    is the i-th unit vector.

    Args:
        vertex_0: Start point; expected to be a 1-D array-like.
        step_length: Offset added along each coordinate axis.

    Returns:
        A list of ``n + 1`` float arrays whose first entry is the start point.
    """
    # np.asarray returns a float ndarray unchanged (no copy), so callers that
    # already pass one still get that very object back as the first vertex.
    vertex_0 = np.asarray(vertex_0, dtype=float)
    emat = np.eye(vertex_0.size) * step_length
    return [vertex_0] + [vertex_0 + emat[:, ii] for ii in range(vertex_0.size)]


def f(v):
    """Evaluate the test objective ``0.5 * (v - 5) . (v - 3)``.

    Each coordinate contributes ``0.5 * (v_i - 5) * (v_i - 3)``, which is
    smallest (``-0.5``) at ``v_i = 4``; the global minimum is therefore
    ``-0.5 * v.size``.

    Args:
        v: Point at which to evaluate; a 1-D array-like.

    Returns:
        The scalar function value.
    """
    v = np.asarray(v, dtype=float)
    return 0.5 * np.dot(v - 5.0, v - 3.0)


def line(t, v1, v2):
    """Return the point ``(1 - t) * v1 + t * v2`` on the line through v1, v2.

    ``t = 0`` gives ``v1`` and ``t = 1`` gives ``v2``; a negative ``t`` lies
    on the far side of ``v1`` as seen from ``v2``.
    """
    return (1 - t) * v1 + t * v2
def simplex(f, vertice, maxit=1000, step_length=100, tol=1e-3):
    """Minimise ``f`` with the downhill simplex (Nelder-Mead) method.

    Every iteration moves the worst vertex along the line through the
    centroid of the remaining vertices, ``x(t) = c + t * (x_worst - c)``:
    reflection (``t = -1``), expansion (``t = -2``), outside contraction
    (``t = -0.5``) or inside contraction (``t = 0.5``).  If none of these
    helps, all vertices are shrunk half-way towards the best one.  When the
    best and worst function values are nearly equal, the simplex is rebuilt
    around the best vertex (restart).

    Args:
        f: Objective, called as ``f(x)`` with one vertex; returns a scalar.
        vertice: List with the ``n + 1`` vertices of the initial simplex
            (1-D arrays of equal length).  The list itself is not modified.
        maxit: Number of iterations to perform.
        step_length: Edge length passed to ``vertice_init`` on a restart.
        tol: Restart threshold for the relative spread
            ``2 * |y_best - y_worst| / (|y_best| + |y_worst| + 1e-9)``.

    Returns:
        A tuple ``(vertice_max_list, vertice_min_list)`` with the worst and
        the best vertex seen at the start of each iteration.
    """
    vertice = list(vertice)  # private copy: never mutate the caller's list
    vertice_max_list = []  # worst vertex of every iteration
    vertice_min_list = []  # best vertex of every iteration

    for _ in range(maxit):
        y = np.array([f(v) for v in vertice], dtype=float)

        # only the best, the second-worst and the worst vertex are needed
        idx = np.argsort(y)
        i_best, i_worst = idx[0], idx[-1]
        v_best, v_worst = vertice[i_best], vertice[i_worst]
        y_best, y_second, y_worst = y[i_best], y[idx[-2]], y[i_worst]
        vertice_max_list.append(v_worst)
        vertice_min_list.append(v_best)

        # centroid of all vertices except the worst one; axis=0 keeps it a
        # vector (a plain np.mean would collapse everything to one scalar)
        v_mean = np.mean(
            [v for k, v in enumerate(vertice) if k != i_worst], axis=0)
        direction = v_worst - v_mean  # x(t) = v_mean + t * direction

        v_ref = v_mean - direction  # t = -1
        y_ref = f(v_ref)

        v_new = None  # stays None when the step fails and we must shrink
        y_new = None
        if y_ref < y_best:
            # reflection beats the best vertex: try to go twice as far
            v_exp = v_mean - 2.0 * direction  # t = -2
            y_exp = f(v_exp)
            if y_exp < y_ref:
                v_new, y_new = v_exp, y_exp  # expansion
            else:
                v_new, y_new = v_ref, y_ref  # reflection
        elif y_ref < y_second:
            # y_best <= y_ref < y_second: plain reflection
            v_new, y_new = v_ref, y_ref
        else:
            # y_ref >= y_second: contract outside the simplex if the
            # reflection at least beats the worst vertex, inside otherwise
            t = -0.5 if y_ref < y_worst else 0.5
            v_con = v_mean + t * direction
            y_con = f(v_con)
            if y_con < y_ref:
                v_new, y_new = v_con, y_con

        if v_new is None:
            # contraction failed: pull every vertex half-way to the best one
            vertice = [v if k == i_best else 0.5 * (v_best + v)
                       for k, v in enumerate(vertice)]
            print('shrinkage')
            # the cached values in y are stale now; they are recomputed at
            # the top of the next iteration, so skip the restart test
            continue

        vertice[i_worst] = v_new
        y[i_worst] = y_new

        # restart: the simplex easily gets stuck at a non-optimal point with
        # almost equal function values; rebuild it around the current best
        y_lo, y_hi = y.min(), y.max()
        rtol = 2.0 * abs(y_lo - y_hi) / (abs(y_lo) + abs(y_hi) + 1e-9)
        if rtol <= tol:
            vertice = vertice_init(vertice[int(np.argmin(y))], step_length)

    return vertice_max_list, vertice_min_list
测试部分,设置未知参数维度为15维。根据目标函数\(f(\mathbf{x})=\frac{1}{2}(\mathbf{x}-5)^{T}(\mathbf{x}-3)\)的定义,每个分量在\(x_i=4\)处取得最小值\(-0.5\),易得该函数的最小值为\(-0.5\times 15=-7.5\)。
# Test problem: minimise the 15-dimensional quadratic f from a random start.
dim = 15
step_length = 5
v = np.random.randn(dim)

# the choice of step length is crucial
vertice = vertice_init(v, step_length)

vertice_max_list, vertice_min_list = simplex(
    f, vertice, maxit=2000, step_length=step_length, tol=1e-5)
print('min value is %s' % f(vertice_min_list[-1]))
# Plot the worst ('max') and best ('min') function value of every iteration.
f_max_list = [f(vertex) for vertex in vertice_max_list]
f_min_list = [f(vertex) for vertex in vertice_min_list]

plt.plot(f_max_list, 'r', linewidth=2, label='max')
plt.plot(f_min_list, 'b', linewidth=2, label='min')
plt.legend(fontsize=15)
plt.show()
import numpy as np
from matplotlib import pyplot as plt


class opt_variables:
    """Flat parameter vector of a one-hidden-layer network.

    All parameters live in one 1-D array ``v`` laid out as
    ``[w1 | b1 | w2 | b2]``; ``ptrs`` stores the end offset of each segment.
    The arithmetic operators act element-wise on ``v`` and return a new
    instance with the same layer sizes (the in-place variants modify ``v``),
    so the simplex routines can treat an instance like a plain vector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, v):
        assert (input_dim + 1) * hidden_dim + (
            hidden_dim + 1) * output_dim == v.size, 'dimension mismatch!'
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.v = v
        self.dim = v.size
        # end offsets of w1, b1, w2 and b2 inside the flat vector
        w1_ptr = input_dim * hidden_dim
        b1_ptr = hidden_dim + w1_ptr
        w2_ptr = hidden_dim * output_dim + b1_ptr
        b2_ptr = output_dim + w2_ptr
        self.ptrs = np.array([w1_ptr, b1_ptr, w2_ptr, b2_ptr])

    def _like(self, values):
        """Wrap ``values`` in a new instance with this instance's layer sizes."""
        return opt_variables(self.input_dim, self.hidden_dim, self.output_dim,
                             values)

    def _operand(self, other):
        """Return the raw array or scalar that ``other`` contributes."""
        if isinstance(other, opt_variables):
            assert self.dim == other.dim, 'dimension mismatch!'
            return other.v
        return other

    def __add__(self, other):
        return self._like(self.v + self._operand(other))

    def __iadd__(self, other):
        self.v += self._operand(other)
        return self

    def __sub__(self, other):
        return self._like(self.v - self._operand(other))

    def __isub__(self, other):
        self.v -= self._operand(other)
        return self

    def __mul__(self, constant):
        return self._like(self.v * constant)

    def __rmul__(self, constant):
        return self._like(self.v * constant)

    def __imul__(self, constant):
        self.v *= constant
        return self

    def __truediv__(self, constant):
        return self._like(self.v / constant)

    def __rtruediv__(self, constant):
        return self._like(constant / self.v)

    def __itruediv__(self, constant):
        self.v /= constant
        return self

    def __str__(self):
        return '%s' % self.v

    def __repr__(self):
        return 'opt_variables(%r, %r, %r, %r)' % (
            self.input_dim, self.hidden_dim, self.output_dim, self.v)


def vertice_init(vertex_0, step_length):
    """Build the initial simplex around ``vertex_0``.

    Args:
        vertex_0: Start point, an ``opt_variables`` instance.
        step_length: Offset added along each coordinate axis.

    Returns:
        A list with ``vertex_0`` followed by ``vertex_0 + step_length * e_i``
        for every coordinate ``i``.
    """
    emat = np.eye(vertex_0.dim) * step_length
    return [vertex_0] + [vertex_0 + emat[:, ii] for ii in range(vertex_0.dim)]


def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` element-wise.

    Written as ``exp(-log(1 + exp(-x)))`` so that large negative inputs do
    not overflow in ``exp``.
    """
    return np.exp(-np.logaddexp(0.0, -x))


def _unpack(v):
    """Split the flat vector of ``v`` into ``(w1, b1, w2, b2)``."""
    assert v.ptrs.size == 4, 'dimension mismatch!'
    w1 = np.reshape(v.v[:v.ptrs[0]], (v.hidden_dim, v.input_dim))
    b1 = np.reshape(v.v[v.ptrs[0]:v.ptrs[1]], v.hidden_dim)
    w2 = np.reshape(v.v[v.ptrs[1]:v.ptrs[2]], (v.output_dim, v.hidden_dim))
    b2 = np.reshape(v.v[v.ptrs[2]:], v.output_dim)
    return w1, b1, w2, b2


def _forward(v, X):
    """Return the network output for every row of ``X``.

    The result has one row per sample and ``v.output_dim`` columns.
    """
    w1, b1, w2, b2 = _unpack(v)
    assert v.input_dim == X.shape[1], 'dimension mismatch!'
    hidden = sigmoid(np.dot(X, w1.T) + b1)
    return np.dot(hidden, w2.T) + b2


def f(v, X, y):
    """Return the summed squared error of the network on ``(X, y)``.

    Args:
        v: Network parameters (``opt_variables``).
        X: Inputs, one sample per row, with ``v.input_dim`` columns.
        y: Targets, one entry per sample.

    Returns:
        The sum over all samples of the squared residual of the first output
        unit.
    """
    outputs = _forward(v, X)
    targets = np.asarray(y)
    if targets.ndim == 1:
        # one scalar target per sample: broadcast it over the output units
        targets = targets[:, np.newaxis]
    loss = np.sum((outputs - targets) ** 2, axis=0)
    return loss[0]


def pred(v, X):
    """Return the network predictions for every row of ``X``.

    Args:
        v: Network parameters (``opt_variables``).
        X: Inputs, one sample per row, with ``v.input_dim`` columns.

    Returns:
        Array with one row per sample and ``v.output_dim`` columns.
    """
    return _forward(v, X)


def line(t, v1, v2):
    """Return the point ``(1 - t) * v1 + t * v2`` on the line through v1, v2."""
    return (1 - t) * v1 + t * v2


def _centroid(points):
    """Return the mean of ``points`` using only ``+`` and ``/``."""
    total = points[0]
    for point in points[1:]:
        total = total + point
    return total / len(points)


def simplex(f, X, y_real, vertice, maxit=1000, tol=1e-7, step_length=100):
    """Minimise ``f(v, X, y_real)`` over ``v`` with the downhill simplex method.

    Every iteration moves the worst vertex along the line through the
    centroid of the remaining vertices: reflection (``t = -1``), expansion
    (``t = -2``), outside contraction (``t = -0.5``) or inside contraction
    (``t = 0.5``).  If none of these helps, all vertices are shrunk half-way
    towards the best one.  When the best and worst function values are
    nearly equal, the simplex is rebuilt around the best vertex (restart).

    Args:
        f: Objective, called as ``f(v, X, y_real)``; returns a scalar.
        X: Training inputs, passed through to ``f``.
        y_real: Training targets, passed through to ``f``.
        vertice: List with the vertices of the initial simplex.  The list
            itself is not modified.
        maxit: Number of iterations to perform.
        tol: Restart threshold for the relative spread
            ``2 * |y_best - y_worst| / (|y_best| + |y_worst| + 1e-9)``.
        step_length: Edge length passed to ``vertice_init`` on a restart.

    Returns:
        A tuple ``(vertice_max_list, vertice_min_list)`` with the worst and
        the best vertex seen at the start of each iteration.
    """
    vertice = list(vertice)  # private copy: never mutate the caller's list
    vertice_max_list = []
    vertice_min_list = []

    for _ in range(maxit):
        # evaluate the function value at every vertex
        y = np.array([f(v, X, y_real) for v in vertice], dtype=float)

        idx = np.argsort(y)  # ascending: idx[0] is best, idx[-1] is worst
        i_best, i_worst = idx[0], idx[-1]
        v_best, v_worst = vertice[i_best], vertice[i_worst]
        y_best, y_second, y_worst = y[i_best], y[idx[-2]], y[i_worst]
        vertice_max_list.append(v_worst)
        vertice_min_list.append(v_best)

        # centroid of all vertices except the worst one
        v_mean = _centroid(
            [v for k, v in enumerate(vertice) if k != i_worst])

        v_ref = line(-1, v_mean, v_worst)
        y_ref = f(v_ref, X, y_real)

        v_new = None  # stays None when the step fails and we must shrink
        y_new = None
        if y_ref < y_best:
            # reflection beats the best vertex: try to go twice as far
            v_exp = line(-2, v_mean, v_worst)
            y_exp = f(v_exp, X, y_real)
            if y_exp < y_ref:
                v_new, y_new = v_exp, y_exp  # expansion
            else:
                v_new, y_new = v_ref, y_ref  # reflection
        elif y_ref < y_second:
            # y_best <= y_ref < y_second: plain reflection
            v_new, y_new = v_ref, y_ref
        else:
            # y_ref >= y_second: contract outside the simplex if the
            # reflection at least beats the worst vertex, inside otherwise
            t = -0.5 if y_ref < y_worst else 0.5
            v_con = line(t, v_mean, v_worst)
            y_con = f(v_con, X, y_real)
            if y_con < y_ref:
                v_new, y_new = v_con, y_con

        if v_new is None:
            # contraction failed: pull every vertex half-way to the best one
            vertice = [v if k == i_best else 0.5 * (v_best + v)
                       for k, v in enumerate(vertice)]
            print('shrinkage')
            # the cached values in y are stale now; they are recomputed at
            # the top of the next iteration, so skip the restart test
            continue

        vertice[i_worst] = v_new
        y[i_worst] = y_new

        # restart around the current best vertex once the simplex is stuck
        y_lo, y_hi = y.min(), y.max()
        rtol = 2.0 * abs(y_lo - y_hi) / (abs(y_lo) + abs(y_hi) + 1e-9)
        if rtol <= tol:
            vertice = vertice_init(vertice[int(np.argmin(y))], step_length)

    return vertice_max_list, vertice_min_list
# define the 3 layer NN
input_dim = 3
hidden_dim = 2
output_dim = 1
total_dim = (input_dim + 1) * hidden_dim + (hidden_dim + 1) * output_dim

# simplex initialize
v = opt_variables(input_dim, hidden_dim, output_dim, np.random.rand(total_dim))
step_length = 3
# the choice of step length is crucial
vertice = vertice_init(v, step_length)

# training data: the target is the sum of the input features
X = np.random.rand(100, input_dim)
y_real = X.sum(axis=1)

# model training
vertice_max_list, vertice_min_list = simplex(
    f, X, y_real, vertice, maxit=200, tol=1e-3, step_length=step_length)

# plot the worst ('max') and best ('min') loss of every iteration
f_max_list = [f(vertex, X, y_real) for vertex in vertice_max_list]
f_min_list = [f(vertex, X, y_real) for vertex in vertice_min_list]
plt.plot(f_max_list, 'r', linewidth=2, label='max')
plt.plot(f_min_list, 'b', linewidth=2, label='min')
plt.legend(fontsize=15)
plt.show()

# prediction on fresh data drawn the same way as the training set
X_test = np.random.rand(100, input_dim)
y_real_test = X_test.sum(axis=1)
y_pred = pred(vertice_min_list[-1], X_test)
plt.plot(y_real_test, 'r', linewidth=2, label='real')
plt.plot(y_pred, 'b', linewidth=2, label='pred')
plt.legend(fontsize=15)
plt.show()