最小均方算法(LMS Algorithm)理论及DSP实现

最小均方算法(LMS Algorithm)理论及DSP实现

 LMS算法可认为是机器学习里面最基本也比较有用的算法,神经网络中对参数的学习使用的就是LMS的思想,在通信信号处理领域LMS也非常常见,比如自适应滤波器。

本文主要对LMS(Least Mean Square)算法进行简单的整理,包括内容:

(1)理论上介绍基于LMS的梯度下降算法(包括BACH/STOCHASTIC),给出一个matlab的实现

(2)DSP上的实现,主要使用C语言


1. LMS算法理论

问题引出

因为本人感兴趣的领域为机器学习,因此这里先说明下学习的过程,给定这样一个问题:某地的房价与房地面积和卧室的数量之间成如下表的关系,

Living area (feet2)       #bedrooms          Price (1000$s)
2104                              3                           400
1600                              3                           330
2400                              3                           369
1416                              2                           232
3000                              4                           540

据此,我们要通过分析上面的数据学习出一个模型,用于预测其它情况(比如面积2000,卧室数5)的房价。这就是一个学习问题,更简洁的说,就是一个概率里的回归问题。这里固定几个符号:x表示输入([Living area,bedrooms]),y表示输出(Price),h表示要学习的模型,m表示输入每个数据维度(这里是2),n表示输入数据的个数(这里是5)。

该学习过程的可以描述如下图,

最小均方算法(LMS Algorithm)理论及DSP实现_第1张图片
.
h必定与面积和卧室数相关,.这里不考虑复杂的情况,假设模型是线性的(实际其它问题中很可能是其它关系模型,比如exp
.
.令x1=1,则。这里,我们考虑上面的房价问题,还是将w0忽略。

为了获得h(x),现在的问题是什么呢?那就是:怎样获得h(x)的w1~w2的值。

我们再对问题进行描述:

已知——上面的数据表格,线性模型(不知道参数)

求解——参数w1~w2

引入一个函数,叫损失函数


就是最小二乘法中计算误差的函数,只是前面添加了1/2,表示什么意思呢?损失函数越小,说明模型与当前已知数据的拟合程度越好,否则越差。因此,求解w1~w2的目标就是求解J(w)最小,这就用到了LMS算法。


LMS算法

LMS算法是一个搜索算法,假设w从某个给定的初始值开始迭代,逐渐使J(W)朝着最小的方向变化,直到达到一个值使J(w)收敛。考虑梯度下降算法(gradient descent algorithm),它通过给定的w值快速的执行如下的更新操作:


其中为学习率(Learning rate)。

要对w更新,首先需要完成上面的求导,求导的结果参见下面的算法流程。

对一个单一的训练实例j,



按照上述的更新方法,对多个实例的更新规则为

Repeat until convergence {

        for every j, exec

                  

}

这种更新的梯度下降方法称为batch gradient descent。还有一种更新的方式:采用随机的样本数据实例,如下

Repeat until convergence {

        for every j, exec

                  

}

这种方法称为stochastic gradient descent (或者incremental gradient descent)。

两种方法的明显区别是batch的训练时间要比stochastic常,但效果可能更好。实际问题中,因为我们只需要找到一个接近使J(w)最小的值即可,因此stochastic更常用。


说了这么久,LMS到底能用来干嘛,其实上面已经很清楚了:参数训练中的求极值


在matlab上对stochastic gradient descent 的实现如下:

[plain] view plaincopyprint?
  1. function [test_targets, a, updates] = LMS(train_patterns, train_targets, test_patterns, params)  
  2.   
  3. % Classify using the least means square algorithm  
  4. % Inputs:  
  5. %   train_patterns  - Train patterns  
  6. %   train_targets   - Train targets  
  7. %   test_patterns   - Test  patterns  
  8. %   param           - [Maximum iteration Theta (Convergence criterion), Convergence rate]  
  9. %  
  10. % Outputs  
  11. %   test_targets    - Predicted targets  
  12. %   a               - Weights vector  
  13. %   updates         - Updates throughout the learning iterations  
  14. %  
  15. % NOTE: Suitable for only two classes  
  16. %  
  17.   
  18. [c, n]                  = size(train_patterns);  
  19. [Max_iter, theta, eta]  = process_params(params);  
  20.   
  21. y               = [train_patterns ; ones(1,n)];  
  22. train_zero      = find(train_targets == 0);  
  23.   
  24. %Preprocessing  
  25. processed_patterns               = y;  
  26. processed_patterns(:,train_zero) = -processed_patterns(:,train_zero);  
  27. b                                = 2*train_targets - 1;   
  28.   
  29. %Initial weights  
  30. a               = sum(processed_patterns')';  
  31. iter            = 1;  
  32. k               = 0;  
  33. update          = 1e3;  
  34. updates         = 1e3;  
  35.   
  36. while ((sum(abs(update)) > theta) & (iter < Max_iter))  
  37.     iter = iter + 1;  
  38.       
  39.     %k <- (k+1) mod n  
  40.     k = mod(k+1,n);  
  41.     if (k == 0),   
  42.         k = n;  
  43.     end  
  44.       
  45.     % a <- a + eta*(b-a'*Yk)*Yk'  
  46.     update  = eta*(b(k) - a'*y(:,k))*y(:,k);  
  47.     a       = a + update;  
  48.       
  49.     updates(iter) = sum(abs(update));  
  50. end  
  51.   
  52. if (iter == Max_iter),  
  53.     disp(['Maximum iteration (' num2str(Max_iter) ') reached']);  
  54. else  
  55.     disp(['Did ' num2str(iter) ' iterations'])  
  56. end  
  57.   
  58. %Classify the test patterns  
  59. test_targets = a'*[test_patterns; ones(1, size(test_patterns,2))];  
  60.   
  61. test_targets = test_targets > 0;  

2. 基于LMS的梯度下降算法在DSP上的实现

下面是我在DSP6713上使用软件仿真实现的LMS算法,

[cpp] view plaincopyprint?
  1. /* 
  2.  * zx_lms.h 
  3.  * 
  4.  *  Created on: 2013-8-4 
  5.  *      Author: monkeyzx 
  6.  */  
  7.   
  8. #ifndef ZX_LMS_H_  
  9. #define ZX_LMS_H_  
  10.   
  11. /* 
  12.  * methods for @lms_st.method 
  13.  */  
  14. #define STOCHASTIC           (0x01)     /* 随机梯度下降 */  
  15. #define BATCH                (0x02)     /* BATCH梯度下降 */  
  16.   
  17. struct lms_st {  
  18.     short method;       /* 0/1 */  
  19.     double *x;          /* features, x0,...,x[n-1] */  
  20.     int n;              /* dimension of features */  
  21.     double *y;          /* given output, y0,..,y[m-1] */  
  22.     int m;              /* number of data set */  
  23.     double *weight;     /* weighs that want to train by using LMS, w0,w1,..,w[n-1] */  
  24.     double lrate;       /* learning rate */  
  25.     double threshhold;  /* if error < threshold, stop iteration */  
  26.     int max_iter;       /* if iter numbers > max_iter, stop iteration, 
  27.                            if max_iter<0, then max_iter is unused */  
  28. };  
  29.   
  30. extern void zx_lms(void);  
  31.   
  32. #endif /* ZX_LMS_H_ */  

[cpp] view plaincopyprint?
  1. /* 
  2.  * zx_lms.c 
  3.  * Least Mean Squares Algorithm 
  4.  *  Created on: 2013-8-4 
  5.  *      Author: monkeyzx 
  6.  */  
  7. #include "zx_lms.h"  
  8. #include "config.h"  
  9. #include <stdio.h>  
  10. #include <stdlib.h>  
  11.   
  12. static double init_y[] = {4.00,3.30,3.69,2.32};  
  13. static double init_x[] = {  
  14.         2.104,3,  
  15.         1.600,3,  
  16.         2.400,3,  
  17.         3.000,4  
  18. };  
  19. static double weight[2] = {0.1, 0.1};  
  20.   
  21. /* 
  22.  * Least Mean Square Algorithm 
  23.  * return value @error when stop iteration 
  24.  * use @lms_prob->method to choose a method. 
  25.  */  
  26. double lms(struct lms_st *lms_prob)  
  27. {  
  28.     double err;  
  29.     double error;  
  30.     int i = 0;  
  31.     int j = 0;  
  32.     int iter = 0;  
  33.     static double *h = 0;       /* 加static,防止栈溢出*/  
  34.   
  35.     h = (double *)malloc(sizeof(double) * lms_prob->m);  
  36.     if (!h) {  
  37.         return -1;  
  38.     }  
  39.     do {  
  40.         error = 0;  
  41.   
  42.         if (lms_prob->method != STOCHASTIC) {  
  43.             i = 0;  
  44.         } else {  
  45.             /* i=(i+1) mod m */  
  46.             i = i + 1;  
  47.             if (i >= lms_prob->m) {  
  48.                 i = 0;  
  49.             }  
  50.         }  
  51.   
  52.         for ( ; i<lms_prob->m; i++) {  
  53.             h[i] = 0;  
  54.             for (j=0; j<lms_prob->n; j++) {  
  55.                 h[i] += lms_prob->weight[j] * lms_prob->x[i*lms_prob->n+j]; /* h(x) */  
  56.             }  
  57.             if (lms_prob->method == STOCHASTIC) break;   /* handle STOCHASTIC */  
  58.         }  
  59.   
  60.         for (j=0; j<lms_prob->n; j++) {  
  61.             if (lms_prob->method != STOCHASTIC) {  
  62.                 i = 0;  
  63.             }  
  64.             for ( ; i<lms_prob->m; i++) {  
  65.                 err = lms_prob->lrate  
  66.                         * (lms_prob->y[i] - h[i]) * lms_prob->x[i*lms_prob->n+j];  
  67.                 lms_prob->weight[j] += err;            /* Update weights */  
  68.                 error += ABS(err);  
  69.                 if (lms_prob->method == STOCHASTIC) break/* handle STOCHASTIC */  
  70.             }  
  71.         }  
  72.   
  73.         iter = iter + 1;  
  74.         if ((lms_prob->max_iter > 0) && ((iter > lms_prob->max_iter))) {  
  75.             break;  
  76.         }  
  77.     } while (error >= lms_prob->threshhold);  
  78.   
  79.     free(h);  
  80.   
  81.     return error;  
  82. }  
  83.   
  84. #define DEBUG  
  85. void zx_lms(void)  
  86. {  
  87.     int i = 0;  
  88.     double error = 0;  
  89.     struct lms_st lms_prob;  
  90.   
  91.     lms_prob.lrate = 0.01;  
  92.     lms_prob.m = 4;  
  93.     lms_prob.n = 2;  
  94.     lms_prob.weight = weight;  
  95.     lms_prob.threshhold = 0.2;  
  96.     lms_prob.max_iter = 1000;  
  97.     lms_prob.x = init_x;  
  98.     lms_prob.y = init_y;  
  99. //  lms_prob.method = STOCHASTIC;  
  100.     lms_prob.method = BATCH;  
  101.   
  102. //  error = lms(init_x, 2, init_y, 4, weight, 0.01, 0.1, 1000);  
  103.     error = lms(&lms_prob);  
  104.   
  105. #ifdef DEBUG  
  106.     for (i=0; i<sizeof(weight)/sizeof(weight[0]); i++) {  
  107.         printf("%f\n", weight[i]);  
  108.     }  
  109.     printf("error:%f\n", error);  
  110. #endif  
  111. }  

输入、输出、初始权值为

static double init_y[] = {4.00,3.30,3.69,2.32};
static double init_x[] = {        /* 用一维数组保存 */
2.104, 3,
1.600, 3,
2.400, 3,
3.000, 4
};
static double weight[2] = {0.1, 0.1};

main函数中只需要调用zx_lms()就可以运行了,本文对两种梯度下降方法做了个简单对比,

max_iter=1000 w1 w2 error CPU Cycles
batch -0.6207369 1.419737 0.20947 2181500
stochastic 
0.145440 0.185220 0.130640 995

需要说明的是:batch算法是达到最大迭代次数1000退出的,而stochastic是收敛退出的,因此这里batch算法应该没有对数据做到较好的拟合。stochastic算法则在时钟周期上只有995,远比batch更有时间上的优势。

注:这里的error没有太大的可比性,因为batch的error针对的整体数据集的error,而stochastic 的error是针对一个随机的数据实例。


LMS有个很重要的问题:收敛。开始时可以根据给定数据集设置w值,使h(x)尽可能与接近y,如果不确定可以将w设置小一点。


这里顺便记录下在调试过程中遇到的一个问题:在程序运行时发现有变量的值为1.#QNAN

解决:QNAN是Quiet Not a Number简写,是常见的浮点溢出错误,在网上找到了解释

QNaN is a NaN with the most significant fraction bit set. QNaN’s propagate freely through most arithmetic operations. These values pop out of an operation when the result is not mathematically defined.

在开始调试过程中因为迭代没有收敛,发散使得w和error等值逐渐累积,超过了浮点数的范围,从而出现上面的错误,通过修改使程序收敛后上面的问题自然而然解决了。



参考:

[1] Andrew Ng的机器学习课程

[2] Richard O.Duda 等,《模式分类》 

你可能感兴趣的:(最小均方算法(LMS Algorithm)理论及DSP实现)