PyTorch in practice: finding the extremum of a function

Taking the Himmelblau function as an example, let's find its minimum:
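Written out, Himmelblau's function and its gradient (the quantity autograd computes when `f.backward()` is called in the optimization loop) are shown below. At the starting point (0, 0) used there, the function value works out to 170, which matches the `value` printed at step 0.

$$
f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2
$$

$$
\frac{\partial f}{\partial x} = 4x\,(x^2 + y - 11) + 2\,(x + y^2 - 7), \qquad
\frac{\partial f}{\partial y} = 2\,(x^2 + y - 11) + 4y\,(x + y^2 - 7)
$$

$$
f(0, 0) = (-11)^2 + (-7)^2 = 121 + 49 = 170, \qquad \nabla f(0, 0) = (-14,\ -22)
$$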

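# picture_3D.py: plot the Himmelblau function as a 3D surface
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

def himmelblau(t):
    # t = (x, y); works element-wise on NumPy arrays and on torch tensors
    return (t[0] ** 2 + t[1] - 11) ** 2 + (t[0] + t[1] ** 2 - 7) ** 2

# Only plot when run as a script, so `from picture_3D import himmelblau` has no side effects
if __name__ == '__main__':
    # Sample x and y on [-6, 6) with step 0.1 and build the 2D grid
    x = np.arange(-6, 6, 0.1)
    y = np.arange(-6, 6, 0.1)
    X, Y = np.meshgrid(x, y)
    Z = himmelblau([X, Y])

    fig = plt.figure()
    # fig.gca(projection='3d') no longer works in recent Matplotlib; add_subplot does
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z)
    ax.view_init(60, -30)   # elevation 60 degrees, azimuth -30 degrees
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.show()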

The resulting plot looks like this:
[Figure 1: 3D surface plot of the Himmelblau function]
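The surface plot only shows the landscape qualitatively. As a rough numerical cross-check, the sketch below evaluates the function on the same 0.1-spaced grid over [-6, 6) in plain Python and returns the lowest grid point. `grid_argmin` is a hypothetical helper, not part of the original code. Because (3, 2) happens to be a grid point, the search lands on it, but a grid search only reports a single best point and is limited by the grid's resolution.

```python
# Hypothetical helper: brute-force search over the same grid the plot uses.

def himmelblau(t):
    # t = (x, y)
    return (t[0] ** 2 + t[1] - 11) ** 2 + (t[0] + t[1] ** 2 - 7) ** 2

def grid_argmin(lo=-6.0, hi=6.0, step=0.1):
    """Return (x, y, value) of the lowest grid point on [lo, hi) x [lo, hi)."""
    n = int(round((hi - lo) / step))            # 120 samples per axis, like np.arange(-6, 6, 0.1)
    coords = [lo + i * step for i in range(n)]
    best = None
    for x in coords:
        for y in coords:
            v = himmelblau((x, y))
            if best is None or v < best[2]:     # keep the lowest value seen so far
                best = (x, y, v)
    return best

bx, by, bv = grid_argmin()
print(f"lowest grid point: ({bx:.2f}, {by:.2f}), f = {bv:.3g}")
```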
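Next, search for the minimum with PyTorch: autograd computes the gradient and the Adam optimizer updates the point, starting from (0, 0):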

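# Assumes the plotting script above is saved as picture_3D.py in the same directory
from picture_3D import himmelblau
import torch

# Starting point (0, 0); requires_grad=True lets autograd track operations on x
x = torch.tensor([0., 0.], requires_grad=True)
optimizer = torch.optim.Adam([x], lr=1e-3)  # 1e-3 is also Adam's default learning rate

for step in range(20001):
    f = himmelblau(x)       # forward pass: evaluate f at the current x
    if step % 1000 == 0:
        print('step:{} , x = {} , value = {}'.format(step, x.tolist(), f.item()))
    optimizer.zero_grad()   # clear the gradient left over from the previous step
    f.backward()            # compute df/dx and store it in x.grad
    optimizer.step()        # move x in place using Adam's update rule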

The output is:

step:0 , x = [0.0, 0.0] , value = 170.0
step:1000 , x = [1.270142912864685, 1.118398904800415] , value = 88.42723083496094
step:2000 , x = [2.332378387451172, 1.9535709619522095] , value = 13.730920791625977
step:3000 , x = [2.8519949913024902, 2.114161729812622] , value = 0.6689225435256958
step:4000 , x = [2.981964111328125, 2.0271568298339844] , value = 0.014858869835734367
step:5000 , x = [2.9991261959075928, 2.0014777183532715] , value = 3.956971340812743e-05
step:6000 , x = [2.999983549118042, 2.0000221729278564] , value = 1.1074007488787174e-08
step:7000 , x = [2.9999899864196777, 2.000013589859009] , value = 4.150251697865315e-09
step:8000 , x = [2.9999938011169434, 2.0000083446502686] , value = 1.5572823031106964e-09
step:9000 , x = [2.9999964237213135, 2.000005006790161] , value = 5.256879376247525e-10
step:10000 , x = [2.999997854232788, 2.000002861022949] , value = 1.8189894035458565e-10
step:11000 , x = [2.9999988079071045, 2.0000014305114746] , value = 5.547917680814862e-11
step:12000 , x = [2.9999992847442627, 2.0000009536743164] , value = 1.6370904631912708e-11
step:13000 , x = [2.999999523162842, 2.000000476837158] , value = 5.6843418860808015e-12
step:14000 , x = [2.999999761581421, 2.000000238418579] , value = 1.8189894035458565e-12
step:15000 , x = [3.0, 2.0] , value = 0.0
step:16000 , x = [3.0, 2.0] , value = 0.0
step:17000 , x = [3.0, 2.0] , value = 0.0
step:18000 , x = [3.0, 2.0] , value = 0.0
step:19000 , x = [3.0, 2.0] , value = 0.0
step:20000 , x = [3.0, 2.0] , value = 0.0
[Finished in 9.1s]
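The run above converges to (3, 2), but Himmelblau's function has four minima, all with value 0: approximately (3, 2), (-2.805118, 3.131312), (-3.779310, -3.283186) and (3.584428, -1.848126). Which one a gradient-based search reaches depends on where it starts. The sketch below illustrates this with plain fixed-step gradient descent and the hand-derived gradient, swapped in for autograd + Adam so that it needs only the standard library. The learning rate, step count and the four extra starting points are arbitrary choices for illustration.

```python
# Sketch: plain gradient descent with a hand-derived gradient, standing in for
# PyTorch autograd + Adam (a deliberate swap so it runs with the standard library only).

def himmelblau(x, y):
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2

def himmelblau_grad(x, y):
    a = x ** 2 + y - 11          # inner term of the first square
    b = x + y ** 2 - 7           # inner term of the second square
    # chain rule: df/dx = 2a*2x + 2b,  df/dy = 2a + 2b*2y
    return 4 * x * a + 2 * b, 2 * a + 4 * y * b

def gradient_descent(x, y, lr=1e-3, steps=10000):
    # fixed-step update: (x, y) <- (x, y) - lr * grad f(x, y)
    for _ in range(steps):
        gx, gy = himmelblau_grad(x, y)
        x, y = x - lr * gx, y - lr * gy
    return x, y

# gradient at (0, 0), i.e. what f.backward() would store in x.grad at step 0
print(himmelblau_grad(0.0, 0.0))  # → (-14.0, -22.0)

# the start used above, plus one arbitrary starting point per quadrant
starts = [(0.0, 0.0), (3.0, 3.0), (-3.0, 3.0), (-3.0, -3.0), (3.0, -3.0)]
for sx, sy in starts:
    x, y = gradient_descent(sx, sy)
    print(f"start ({sx}, {sy}) -> ({x:.6f}, {y:.6f}), f = {himmelblau(x, y):.2e}")
```

With these settings each quadrant's start should end at the minimum lying in that quadrant, while (0, 0) ends at (3, 2), as in the Adam run. To try the same with PyTorch, rerun the loop above with a different initial tensor, for example `torch.tensor([-3., -3.], requires_grad=True)`.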
