Before diving into Haskell, let’s go over exactly what the conjugate gradient method is and why it works. The “normal” conjugate gradient method is a method for solving systems of linear equations. However, it extends to a method for minimizing quadratic functions, which we can subsequently generalize to minimizing arbitrary functions $f : \mathbb{R}^n \to \mathbb{R}$. We will start by going over the conjugate gradient method for minimizing quadratic functions, and generalize later.
Suppose we have some quadratic function
$$f(x) = \frac{1}{2} x^T A x + b^T x + c$$
for some matrix $A$, vector $b$, and constant $c$.
We can write any quadratic function in this form, since the quadratic term generates every product $x_i x_j$ with an arbitrary coefficient, and the remaining terms supply the linear and constant parts. In addition, we can assume that $A = A^T$ ($A$ is symmetric). (If it were not, we could just rewrite this with a symmetric $A$, since we could take the term for $x_i x_j$ and the term for $x_j x_i$, sum them, and then have $A_{ij} = A_{ji}$ both be half of this sum.)
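As a quick check that this symmetrization changes nothing, note that $x^T A x$ is a scalar and is therefore equal to its own transpose:
$$x^T A x = \left(x^T A x\right)^T = x^T A^T x \quad\Longrightarrow\quad x^T A x = x^T \left(\frac{A + A^T}{2}\right) x,$$
so replacing $A$ by its symmetric part $\frac{1}{2}(A + A^T)$ leaves $f$ unchanged.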
Taking the gradient of $f$, we obtain
$$\nabla f(x) = A x + b.$$
If we evaluate $-\nabla f$ at any given location, it will give us a vector pointing in the direction of steepest descent. This gives us a natural way to start our algorithm: pick some initial guess $x_0$, compute the negative gradient $-\nabla f(x_0)$, and move in that direction by some step size $\alpha$. Unlike normal gradient descent, however, we do not have a fixed step size $\alpha$; instead, we perform a line search in order to find the best $\alpha$. This $\alpha$ is the value which brings us to the minimum of $f$ if we are constrained to move in the direction given by $d_0 = -\nabla f(x_0)$.
Note that computing $\alpha$ is equivalent to minimizing the univariate function
$$g(\alpha) = f(x_0 + \alpha d_0).$$
The minimum of this function occurs when $g'(\alpha) = 0$, that is, when
$$\alpha = -\frac{d_0^T (A x_0 + b)}{d_0^T A d_0}.$$
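For completeness, here is the short computation behind that value, using the chain rule and the fact that $\nabla f(x) = A x + b$:
$$g'(\alpha) = d_0^T \nabla f(x_0 + \alpha d_0) = d_0^T \bigl(A(x_0 + \alpha d_0) + b\bigr) = d_0^T (A x_0 + b) + \alpha\, d_0^T A d_0,$$
and setting this to zero gives the expression for $\alpha$ above.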
We define two vectors $x$ and $y$ to be conjugate with respect to some positive semi-definite matrix $A$ if $x^T A y = 0$. (Positive semi-definite matrices are ones where $x^T A x \ge 0$ for all $x$, and are what we require for conjugate gradient.) The key idea of conjugate gradient is that each new search direction should be conjugate to the previous ones, so that later steps do not undo the progress made by earlier ones.
Since we have already moved in the $d_0 = -\nabla f(x_0)$ direction, we must find a new direction $d_1$ to move in that is conjugate to $d_0$. How do we do this? Well, let’s compute $d_1$ by starting with the gradient at $x_1$ and then subtracting off anything that would counteract the previous direction:
$$d_1 = -\nabla f(x_1) + \beta_0 d_0,$$
where $\beta_0$ is chosen so that $d_1$ is conjugate to $d_0$, that is, $d_1^T A d_0 = 0$.
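Imposing that conjugacy condition on this form of $d_1$ pins down $\beta_0$:
$$0 = d_1^T A d_0 = \bigl(-\nabla f(x_1) + \beta_0 d_0\bigr)^T A d_0 = -\nabla f(x_1)^T A d_0 + \beta_0\, d_0^T A d_0 \quad\Longrightarrow\quad \beta_0 = \frac{\nabla f(x_1)^T A d_0}{d_0^T A d_0},$$
which is exactly the $\beta_i$ that appears in step 4 below.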
Thus, the full Conjugate Gradient algorithm for quadratic functions:
Let $f$ be a quadratic function $f(x) = \frac{1}{2} x^T A x + b^T x + c$ which we wish to minimize.
1. Initialize: Let $i = 0$ and $x_i = x_0$ be our initial guess, and compute $d_i = d_0 = -\nabla f(x_0)$.
2. Find best step size: Compute $\alpha$ to minimize the function $f(x_i + \alpha d_i)$ via the equation
   $$\alpha = -\frac{d_i^T (A x_i + b)}{d_i^T A d_i}.$$
3. Update the current guess: Let $x_{i+1} = x_i + \alpha d_i$.
4. Update the direction: Let $d_{i+1} = -\nabla f(x_{i+1}) + \beta_i d_i$ where $\beta_i$ is given by
   $$\beta_i = \frac{\nabla f(x_{i+1})^T A d_i}{d_i^T A d_i}.$$
5. Iterate: Repeat steps 2-4 until we have looked in $n$ directions, where $n$ is the size of our vector space (the dimension of $x$).
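To make the algorithm concrete, here is a minimal Haskell sketch of the quadratic case, using plain lists as vectors and lists of rows as matrices. The names here (`Vector`, `mulMV`, `conjGradQuadratic`, and so on) are just illustrative choices for this post, not from any particular library, and the code favors clarity over efficiency.

```haskell
type Vector = [Double]
type Matrix = [[Double]]  -- list of rows

dot :: Vector -> Vector -> Double
dot xs ys = sum (zipWith (*) xs ys)

mulMV :: Matrix -> Vector -> Vector
mulMV m v = map (`dot` v) m

add :: Vector -> Vector -> Vector
add = zipWith (+)

scale :: Double -> Vector -> Vector
scale c = map (c *)

-- Gradient of f(x) = 1/2 x^T A x + b^T x + c (with A symmetric) is A x + b.
gradient :: Matrix -> Vector -> Vector -> Vector
gradient a b x = mulMV a x `add` b

-- Run n conjugate gradient steps on the quadratic defined by (a, b),
-- starting from the initial guess x0, where n is the dimension of x0.
conjGradQuadratic :: Matrix -> Vector -> Vector -> Vector
conjGradQuadratic a b x0 = go (length x0) x0 (scale (-1) (gradient a b x0))
  where
    go 0 x _ = x
    go k x d =
      let alpha = negate (d `dot` gradient a b x) / (d `dot` mulMV a d)  -- step 2
          x'    = x `add` scale alpha d                                  -- step 3
          g'    = gradient a b x'
          beta  = (g' `dot` mulMV a d) / (d `dot` mulMV a d)             -- step 4
          d'    = scale (-1) g' `add` scale beta d
      in  go (k - 1) x' d'

main :: IO ()
main = print (conjGradQuadratic [[4, 1], [1, 3]] [-1, -2] [0, 0])  -- tiny 2D example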
So, now that we’ve derived this for quadratic functions, how are we going to use this for general nonlinear optimization of differentiable functions? To do this, we’re going to reformulate the above algorithm in slightly more general terms.
First of all, we will revise step two. Instead of
Find best step size: Compute $\alpha$ to minimize the function $f(x_i + \alpha d_i)$ via the equation
$$\alpha = -\frac{d_i^T (A x_i + b)}{d_i^T A d_i}.$$
we will simply use a line search:
Find best step size: Compute $\alpha$ to minimize the function $f(x_i + \alpha d_i)$ via a line search in the direction $d_i$.
In addition, we must reformulate the computation of $\beta_i$. There are several ways to do this, all of which are equivalent in the quadratic case but differ in the general nonlinear case. Note that the difference between $x_{k+1}$ and $x_k$ is entirely in the direction $d_k$, so that for some constant $c$, $x_{k+1} - x_k = c\, d_k$. Since $\nabla f(x) = A x + b$,
$$\nabla f(x_{k+1}) - \nabla f(x_k) = A(x_{k+1} - x_k) = c\, A d_k,$$
so we may replace $A d_k$ in both the numerator and the denominator of $\beta_k$ with the difference of successive gradients (the constant $c$ cancels). This gives the revised step:
Update the direction: Let $d_{k+1} = -\nabla f(x_{k+1}) + \beta_k d_k$ where $\beta_k$ is given by
$$\beta_k = \frac{\nabla f(x_{k+1})^T \left(\nabla f(x_{k+1}) - \nabla f(x_k)\right)}{d_k^T \left(\nabla f(x_{k+1}) - \nabla f(x_k)\right)}.$$
We can now apply this algorithm to any nonlinear, differentiable function! This reformulation of $\beta$ is known as the Polak-Ribière method; note that there are others (such as Fletcher-Reeves), similar in form and also in common use.
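Here is a rough Haskell sketch of the resulting nonlinear algorithm, again using plain lists as vectors (the helpers are the same as in the quadratic sketch above). It is parameterized over the gradient of $f$ and over a line search routine, such as one of the methods described below; as before, all names are illustrative rather than from a library.

```haskell
type Vector = [Double]

dot :: Vector -> Vector -> Double
dot xs ys = sum (zipWith (*) xs ys)

add, sub :: Vector -> Vector -> Vector
add = zipWith (+)
sub = zipWith (-)

scale :: Double -> Vector -> Vector
scale c = map (c *)

-- One round of n nonlinear conjugate gradient steps (n = dimension of x0),
-- using the Polak-Ribiere formula for beta.
--   grad       : gradient of the objective f
--   lineSearch : given a point x and a direction d, returns an alpha that
--                (approximately) minimizes f(x + alpha * d)
nonlinearCG :: (Vector -> Vector)            -- gradient of f
            -> (Vector -> Vector -> Double)  -- line search
            -> Vector                        -- initial guess x0
            -> Vector
nonlinearCG grad lineSearch x0 = go (length x0) x0 (scale (-1) (grad x0))
  where
    go 0 x _ = x
    go k x d =
      let alpha = lineSearch x d                -- step 2: line search
          x'    = x `add` scale alpha d         -- step 3: update the guess
          g'    = grad x'
          dg    = g' `sub` grad x               -- grad f(x_{k+1}) - grad f(x_k)
          beta  = (g' `dot` dg) / (d `dot` dg)  -- step 4: Polak-Ribiere beta
          d'    = scale (-1) g' `add` scale beta d
      in  go (k - 1) x' d'
```

In practice one would typically restart (for instance, by calling `nonlinearCG` again from the point it returns) after every $n$ steps, and guard against a zero denominator in the $\beta$ computation; this sketch just shows the shape of the algorithm.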
The one remaining bit of this process that we haven’t covered is step two: the line search. As you can see above, we are given a point $x$, some direction vector $v$, and a multivariate function $f : \mathbb{R}^n \to \mathbb{R}$, and we wish to find the $\alpha$ which minimizes $f(x + \alpha v)$. Note that a line search can be viewed simply as root finding, since we know that $v \cdot \nabla f(x + \alpha v)$ should be zero at the minimum (if it were non-zero, we could move from that point to a better location along $v$).
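As a small illustration of this reduction, here is a hypothetical helper (not from any library) that, given the gradient of $f$, a point $x$, and a direction $v$, produces the univariate function $g(\alpha) = v \cdot \nabla f(x + \alpha v)$ whose root we want; any of the root finders below could then be applied to it.

```haskell
type Vector = [Double]

-- Directional derivative of f along v, evaluated at x + alpha * v. A root of
-- this function (in alpha) is exactly what the line search is looking for.
lineDerivative :: (Vector -> Vector)  -- gradient of f
               -> Vector              -- starting point x
               -> Vector              -- search direction v
               -> Double              -- alpha
               -> Double              -- v . grad f(x + alpha * v)
lineDerivative grad x v alpha =
  sum (zipWith (*) v (grad (zipWith (+) x (map (alpha *) v))))
```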
There are many ways to do this line search, ranging from relatively simple methods based on linear interpolation (like the secant method) to more complex ones (using quadratic or cubic polynomial approximations).
One simple method for a line search is known as the bisection method. The bisection method is simply a binary search. To find a root of a univariate function $g(x)$, it begins with two points, $a$ and $b$, such that $g(a)$ and $g(b)$ have opposite signs. By the intermediate value theorem, $g(x)$ must have a root in $[a, b]$. (Note that in our case, $g(\alpha) = v \cdot \nabla f(x + \alpha v)$.) It then computes their midpoint, $c = \frac{a + b}{2}$, and evaluates the function $g$ to compute $g(c)$. If $g(a)$ and $g(c)$ have opposite signs, the root must be in $[a, c]$; if $g(c)$ and $g(b)$ have opposite signs, then $[c, b]$ must contain the root. At this point, the method recurses, continuing its search until it has gotten close enough to the true $\alpha$.
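A minimal Haskell sketch of this search might look as follows, assuming $a < b$ and that $g(a)$ and $g(b)$ indeed have opposite signs; the tolerance argument and the name `bisection` are my own choices.

```haskell
-- Bisection: binary search for a root of g in [a, b], assuming a < b and
-- that g a and g b have opposite signs. Stops once the bracket is narrower
-- than the tolerance tol.
bisection :: (Double -> Double) -> Double -> Double -> Double -> Double
bisection g a b tol
  | b - a < tol    = c
  | g a * g c <= 0 = bisection g a c tol  -- sign change in [a, c]
  | otherwise      = bisection g c b tol  -- sign change in [c, b]
  where
    c = (a + b) / 2  -- midpoint

-- Example: bisection (\x -> x * x - 2) 0 2 1e-12 is roughly sqrt 2.
```

Combined with the `lineDerivative` helper above, something like `\x d -> bisection (lineDerivative grad x d) 0 1 1e-8` would be one crude way to build the line search argument that `nonlinearCG` expects, assuming the sign change is bracketed in $[0, 1]$.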
Another simple method is known as the secant method. Like the bisection method, the secant method requires two initial points $a$ and $b$ such that $g(a)$ and $g(b)$ have opposite signs. However, instead of doing a simple binary search, it does linear interpolation. It finds the line between $(a, g(a))$ and $(b, g(b))$,
$$y = g(a) + \frac{g(b) - g(a)}{b - a} (x - a),$$
takes the point where this line crosses zero,
$$c = b - g(b)\,\frac{b - a}{g(b) - g(a)},$$
as its next guess, and then recurses using $c$ in place of one of the two previous points.
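A corresponding sketch of the secant update is below. Note that this plain form simply keeps the two most recent points rather than maintaining the opposite-sign bracket; the bracket-preserving variant is usually called false position (regula falsi).

```haskell
-- Secant method: repeatedly replace the older of the two points with the
-- point where the line through (a, g a) and (b, g b) crosses zero, stopping
-- once |g| at the new point is below the tolerance tol.
secant :: (Double -> Double) -> Double -> Double -> Double -> Double
secant g a b tol
  | abs (g c) < tol = c
  | otherwise       = secant g b c tol
  where
    c = b - g b * (b - a) / (g b - g a)  -- root of the secant line

-- Example: secant (\x -> x * x - 2) 0 2 1e-12 is roughly sqrt 2.
```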
There are more line search methods, but the last one we will examine is known as Dekker’s method. Dekker’s method is a combination of the secant method and the bisection method. Unlike the previous two methods, Dekker’s method keeps track of three points: the current guess $b_k$ for the root, a contrapoint $a_k$ chosen so that $g(a_k)$ and $g(b_k)$ have opposite signs (and thus bracket the root), and the previous guess $b_{k-1}$.
Dekker’s method then computes the two possible next values: $m$ (by using the bisection method) and $s$ (by using the secant method with $b_k$ and $b_{k-1}$). (On the very first iteration, $b_{k-1} = a_k$ and it uses the bisection method.) If the secant result $s$ lies between $b_k$ and $m$, then let $b_{k+1} = s$; otherwise, let $b_{k+1} = m$.
After $b_{k+1}$ is chosen, it is checked for convergence. If the method has converged, iteration is stopped; if not, the method continues. A new contrapoint $a_{k+1}$ is chosen such that $g(a_{k+1})$ and $g(b_{k+1})$ have opposite signs. The two choices for $a_{k+1}$ are either for it to remain unchanged (stay $a_k$) or for it to become $b_k$; the choice depends on the signs of the function values involved. Before repeating, the values of $g(a_{k+1})$ and $g(b_{k+1})$ are compared, and $b_{k+1}$ is swapped with $a_{k+1}$ if it has the larger absolute function value, so that $b_{k+1}$ is always the better guess for the root. Finally, the method repeats with the new values of $a_k$, $b_k$, and $b_{k-1}$.
Dekker’s method is effectively a heuristic method, but is nice in practice; it has the reliability of the bisection method and gains a boost of speed from its use of the secant method.
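Finally, here is a rough Haskell sketch of Dekker’s method as described above. It keeps the current guess `b`, the contrapoint `a` (with `g a` and `g b` of opposite signs), and the previous guess `bPrev`; the convergence test and variable names are my own simplifications, and a production implementation would add the extra safeguards found in Brent’s method.

```haskell
-- Dekker's method: combine a bisection candidate and a secant candidate,
-- keeping the root bracketed between the contrapoint a and the guess b.
-- Assumes g a0 and g b0 have opposite signs.
dekker :: (Double -> Double) -> Double -> Double -> Double -> Double
dekker g a0 b0 tol = go a0 b0 a0  -- on the first iteration, b_{k-1} = a_k
  where
    go a b bPrev
      | abs (g b) < tol || abs (b - a) < tol = b
      | otherwise =
          let m = (a + b) / 2                   -- bisection candidate
              s | g b /= g bPrev = b - g b * (b - bPrev) / (g b - g bPrev)
                | otherwise      = m            -- secant candidate
              b' | (s - b) * (s - m) <= 0 = s   -- s lies between b and m
                 | otherwise              = m
              a' | g a * g b' < 0 = a           -- keep the sign change bracketed
                 | otherwise      = b
              -- keep b as the endpoint with the smaller |g| (the better guess)
              (a'', b'') | abs (g a') < abs (g b') = (b', a')
                         | otherwise               = (a', b')
          in  go a'' b'' b

-- Example: dekker (\x -> x * x - 2) 0 2 1e-12 is roughly sqrt 2.
```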