用pytorch做简单的最优化问题

pytorch的自动求导很好用,可以利用它对一些求导困难的问题做一些最优化问题,比如昨天狗菜提了一个问题:

求一个三维点的位置,使得它到一个直线族(三维)的距离之和最小

实际上就是求如下最优化问题
m i n ∑ i ∣ A i x + B i y + C i z + D i ∣ A i 2 + B i 2 + C i 2 + C i 2 min\sum_i \frac{|A_ix+B_iy+C_iz+D_i|}{\sqrt{ A^2_i+B^2_i+C^2_i+C^2_i}} miniAi2+Bi2+Ci2+Ci2 Aix+Biy+Ciz+Di

这个含绝对值的求导不太好做,如果交给Pytorch的话就比较容易了,构造一个简单的例子,即过(1,1,1)的三条直线,显然,最优解就是(1,1,1):

from torch.autograd import Variable
import torch

#Ax+By+Cz+(-A-B-C)=0
tmp1 = torch.rand(3)
a1,b1,c1 = tmp1
d1 = -sum(tmp1)

tmp1 = torch.rand(3)
a2,b2,c2 = tmp1
d2 = -sum(tmp1)

tmp1 = torch.rand(3)
a3,b3,c3 = tmp1
d3 = -sum(tmp1)


x = torch.rand(1)
y = torch.rand(1)
z = torch.rand(1)
x = Variable(x,requires_grad=True)
y = Variable(y,requires_grad=True)
z = Variable(z,requires_grad=True)

lr = 0.0001
for i in range(20000):
    l1 = abs(a1*x+b1*y+c1*z+d1)/torch.sqrt(a1**2+b1**2+c1**2+d1**2)
    l2 = abs(a2*x+b2*y+c2*z+d2)/torch.sqrt(a2**2+b2**2+c2**2+d2**2)
    l3 = abs(a3*x+b3*y+c3*z+d3)/torch.sqrt(a3**2+b3**2+c3**2+d3**2)
    loss = l1 + l2 + l3
    loss.backward()
    x.data = x.data - lr*x.grad.data
    y.data = y.data - lr*y.grad.data
    z.data = z.data - lr*z.grad.data
    x.grad.data.zero_()
    y.grad.data.zero_()#重要
    z.grad.data.zero_()
x,y,z   

最后的结果是:

(tensor([1.0004], requires_grad=True),
 tensor([0.9996], requires_grad=True),
 tensor([0.9999], requires_grad=True))

同样的框架可以用来做很多简单的最优化问题

你可能感兴趣的:(算法学习)