Logistic regression example
This page works through an example of fitting a logistic model with the iteratively-reweighted least squares (IRLS) algorithm. If you'd like to examine the algorithm in more detail, here is Matlab code together with a usage example. (See also old code.)
A logistic model predicts a binary output y from real-valued inputs x according to the rule:
p(y=1) = g(x.w)
g(z) = 1 / (1 + exp(-z))
where w is a vector of adjustable parameters. That is, the probability that y=1 is determined by a linear function of x, followed by a nonlinear monotone function g (called the link function on this page; strictly, the link function of a GLM is the inverse of g) which makes sure that the probability is between 0 and 1. The logistic model is an example of a generalized linear model or GLIM; other GLIMs differ only in that they have different link functions.
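As a minimal sketch of the rule above, the following Python evaluates the model for one input. The names g and predict, the parameter values, and the convention of feeding a constant 1 as the first input (to act as an intercept) are illustrative assumptions, not taken from the Matlab code.

```python
import math

def g(z):
    """Logistic function g(z) = 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + math.exp(-z))

def predict(x, w):
    """Probability that y = 1: g applied to the dot product x.w."""
    z = sum(xi * wi for xi, wi in zip(x, w))
    return g(z)

# Hypothetical parameters: w[0] multiplies a constant input of 1 (an intercept),
# w[1] multiplies the petal length in cm.
w = [-10.0, 2.0]
p = predict([1.0, 5.0], w)   # z = -10 + 2*5 = 0, so p = 0.5
```

Whatever values w takes, the output of predict stays strictly between 0 and 1, which is the job of g.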
The IRLS algorithm is Newton's method applied to the problem of maximizing the likelihood of some outputs y given corresponding inputs x. It is an iterative algorithm; it starts with a guess at the parameter vector w, and on each iteration it solves a weighted least squares problem to find a new parameter vector.
Here is an example of a logistic regression problem with one input and one output:
We are predicting the species of an iris (either I. versicolor, which we have coded as y=0, or I. virginica, which we have coded as y=1) from the length of one of its petals (on the x axis, in cm). The crosses are our training data, which are measurements of the petals of irises whose species is known. The monotonically increasing curve is our prediction: given a new petal measurement, what is the probability that it came from an I. virginica? (This is not the maximum likelihood prediction curve; instead it is taken from one of the middle iterations of IRLS, before it has converged.) The other curve is the estimated standard deviation of y. If our predicted probability is p, then our predicted variance is p(1-p). (It turns out that, in general, the variance is related to the derivative of the link function g'(w.x).)
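For the logistic function specifically, the predicted variance p(1-p) is exactly the derivative g'(z). The sketch below checks this numerically with a central difference; the helper names and the step size h are assumptions chosen for illustration.

```python
import math

def g(z):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))

def predicted_variance(z):
    """Variance of a binary outcome whose probability of being 1 is p = g(z)."""
    p = g(z)
    return p * (1.0 - p)

def g_prime_numeric(z, h=1e-6):
    """Central-difference estimate of the derivative g'(z)."""
    return (g(z + h) - g(z - h)) / (2.0 * h)

# Compare the two quantities at a few arbitrary values of z.
gaps = [abs(predicted_variance(z) - g_prime_numeric(z)) for z in (-3.0, -1.0, 0.0, 2.0)]

# The standard deviation plotted as the second curve would be the square root.
std_at_zero = math.sqrt(predicted_variance(0.0))   # sqrt(0.25) = 0.5
```

The variance peaks at 0.25 where p = 0.5 and shrinks toward zero as the prediction becomes more confident in either direction.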
At every iteration, IRLS builds and solves a weighted linear regression problem whose weights are the predicted standard deviations, sqrt(p(1-p)), of the training points. Here is an example of such a problem:
The straight line is the linear portion of our prediction; if we were to apply the link function g to the height of each point on the line, we would get the prediction curve in the previous picture. The crosses are our training data again; the x values are the same, but the y values have been adjusted by the process described below so that they lie closer to a straight line.
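A weighted linear regression of this kind can be sketched as follows, assuming NumPy is available. The sketch assumes the convention that each weight multiplies its point's residual before squaring, so that scaling the rows by the weights reduces the problem to ordinary least squares. The function name and the toy data are hypothetical.

```python
import numpy as np

def weighted_least_squares(X, t, s):
    """Minimize sum_i (s_i * (t_i - X_i . w))**2 over w.

    X: (n, d) inputs, t: (n,) targets, s: (n,) per-point weights.
    Scaling each row of X and each target by its weight turns the
    weighted problem into an ordinary least squares problem.
    """
    Xs = X * s[:, None]
    ts = t * s
    w, *_ = np.linalg.lstsq(Xs, ts, rcond=None)
    return w

# Toy data lying exactly on the line t = 1 + 2x; the first column is a constant 1.
X = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
t = 1.0 + 2.0 * X[:, 1]
s = np.array([1.0, 0.5, 2.0, 1.0])
w = weighted_least_squares(X, t, s)   # recovers intercept 1 and slope 2
```

When the targets do not lie exactly on a line, points with larger weights pull the fitted line toward themselves more strongly.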
An adjusted y value depends on several things: the original y value, the linear part of our prediction z=x.w, our prediction p=g(z), and the derivative v=g'(z). It is given by the formula
adjusted_y = z + (y - p) / v
We can interpret this formula as stretching out the prediction error (y-p) according to the inverse variance: prediction errors on low-variance points become more important than prediction errors on high-variance points. (This effect is partly counteracted by the lower weights of the low-variance points, but only partly.) We can derive the formula by setting the derivative of our log likelihood to zero and performing a Taylor expansion of the resulting equations around our current estimate of w; rearranging terms in this Taylor expansion yields a set of normal equations in which the dependent variables are given by the above formula.
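To make the stretching effect concrete, here is a small sketch of the adjusted y formula for the logistic case, where v = g'(z) = p(1-p). The function name adjusted_y and the sample values of y and z are illustrative assumptions.

```python
import math

def g(z):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))

def adjusted_y(y, z):
    """Adjusted target z + (y - p) / v for one training point."""
    p = g(z)
    v = p * (1.0 - p)        # g'(z) for the logistic function
    return z + (y - p) / v

# High-variance point: z = 0 gives p = 0.5 and v = 0.25.
a = adjusted_y(1, 0.0)       # 0 + 0.5 / 0.25 = 2.0

# Low-variance point that is badly mispredicted: z = 3 gives p close to 1,
# but the true label is 0, so the error is divided by a small v.
b = adjusted_y(0, 3.0)       # about -18.1
```

The second point's error of roughly -0.95 is stretched to about -21 before being added to z, which is the amplification of errors on low-variance points described above.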
To summarize, the IRLS algorithm is Newton's method for fitting a GLIM by maximum likelihood. It repeatedly updates a guess at the parameter vector by forming a weighted least squares problem. The x values in this WLS problem are taken straight from the training data; the y values are adjusted from the training data according to the formula above; and the weights are sqrt(g'(x.w)).
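The summary above can be sketched as a short loop, assuming NumPy. This is not the Matlab code mentioned at the top of the page. The function name, the fixed iteration count, the zero starting guess, and the made-up data (not the iris measurements) are all assumptions. The sketch has no safeguards: if the classes were perfectly separable, v would approach zero and the division would break down.

```python
import numpy as np

def irls_logistic(X, y, iterations=25):
    """Fit logistic model parameters w by IRLS (a bare sketch, no safeguards)."""
    w = np.zeros(X.shape[1])                 # initial guess at the parameters
    for _ in range(iterations):
        z = X @ w                            # linear part of the prediction
        p = 1.0 / (1.0 + np.exp(-z))         # prediction p = g(z)
        v = p * (1.0 - p)                    # v = g'(z), the predicted variance
        adj = z + (y - p) / v                # adjusted y values
        s = np.sqrt(v)                       # weights sqrt(g'(x.w))
        # Weighted least squares: scale rows by the weights, then solve.
        w, *_ = np.linalg.lstsq(X * s[:, None], adj * s, rcond=None)
    return w

# Made-up, overlapping one-input data; the first column is a constant 1.
x = np.array([4.0, 4.5, 5.0, 5.5, 4.8, 5.2, 5.6, 6.0])
y = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0])
X = np.column_stack([np.ones_like(x), x])
w = irls_logistic(X, y)
```

At a maximum of the likelihood the gradient X'(y - p) is zero, so evaluating that quantity at the returned w is one simple way to check that the iteration has converged.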
This page was written by Geoff Gordon and last updated June 5, 2002.