自组织神经网络SOM原理——结合例子MATLAB实现

  本文主要内容为SOM神经网络原理的介绍,并结合实例给出相应的MATLAB代码实现,方便初学者接触学习,本人才疏学浅,如有纰漏,还望各路大神积极指点。

一、SOM神经网络介绍

     自组织映射神经网络, 即Self Organizing Maps (SOM), 可以对数据进行无监督学习聚类。它的思想很简单,本质上是一种只有输入层--隐藏层的神经网络。隐藏层中的一个节点代表一个需要聚成的类。训练时采用“竞争学习”的方式,每个输入的样例在隐藏层中找到一个和它最匹配的节点,称为它的激活节点,也叫“winning neuron”。 紧接着用随机梯度下降法更新激活节点的参数。同时,和激活节点临近的点也根据它们距离激活节点的远近而适当地更新参数。

      所以,SOM的一个特点是,隐藏层的节点是有拓扑关系的。这个拓扑关系需要我们确定,如果想要一维的模型,那么隐藏节点依次连成一条线;如果想要二维的拓扑关系,那么就行成一个平面,如下图所示(也叫Kohonen Network):

      既然隐藏层是有拓扑关系的,所以我们也可以说,SOM可以把任意维度的输入离散化到一维或者二维(更高维度的不常见)的离散空间上。 Computation layer里面的节点与Input layer的节点是全连接的。

拓扑关系确定后,开始计算过程,大体分成几个部分:

1) 初始化:每个节点随机初始化自己的参数。每个节点的参数个数与Input的维度相同。

2)对于每一个输入数据,找到与它最相配的节点。假设输入时D维的, 即 X={x_i, i=1,...,D},那么判别函数可以为欧几里得距离:

3) 找到激活节点I(x)之后,我们也希望更新和它临近的节点。令S_ij表示节点i和j之间的距离,对于I(x)临近的节点,分配给它们一个更新权重:

简单地说,临近的节点根据距离的远近,更新程度要打折扣。

4)接着就是更新节点的参数了。按照梯度下降法更新:

迭代,直到收敛。



二、问题描述

用26个英文字母作为SOM输入样本。每个字符对应一个5维向量,各字符与向量的关系如表4-2所示。由表4-2可以看出,代表A、B、C、D、E的各向量中有4个分量相同,即,因此,A、B、C、D、E应归为一类;代表F、G、H、I、J的向量中有3个分量相同,同理也应归为一类;依此类推。这样就可以由表4-2中输入向量的相似关系,将对应的字符标在图4-8所示的树形结构图中。用SOM网络对其他进行聚类分析。




三、MATLAB代码实现

SOM_mian.m

[html]  view plain  copy
  1. %%% 神经网络之自组织网络SOM练习  
  2. %%%作者:xd.wp  
  3. %%%时间:2016.10.02 19:16  
  4. %% 程序说明:  
  5. %%%          1、本程序中,输出层为二维平面,  
  6. %%%          2、几何邻域确定及调整权值采用exp(-distant^2/delta^2)函数  
  7. %%%          3、样本维数为5,输出层结点为70  
  8. %%%          4、输入数据,归一化为单位向量  
  9. clear all;  
  10. clc;  
  11. %% 网络初始化及相应参数初始化  
  12. %加载数据并归一化  
  13. [train_data,train_label]=SOM_data_process();  
  14. data_num=size(train_data,2);  
  15.   
  16. %权值初始化  
  17. weight_temp=ones(5,70)/1000;  
  18. weight_temp=rand(5,70)/1000;  
  19.   
  20. %结点个数  
  21. node_num=size(weight_temp,2);  
  22.   
  23. %权值归一化  
  24. for i=1:node_num  
  25.     weight(:,i)=weight_temp(:,i)/max(weight_temp(:,i));      
  26. end  
  27.   
  28. %邻域函数参数  
  29. delta=2;  
  30.   
  31. %调整步幅  
  32. alpha=0.6;  
  33. %% Kohonen算法学习过程  
  34. for t=4:-1:1                                    %%总体迭代次数  
  35.     index_active=ones(1,node_num);              %%结点活跃标志  
  36.     for n=1:data_num                            %%每个样本的输入  
  37.         % 竞争部分,根据最小距离确定获胜神经元  
  38.         [j_min]=SOM_compare(weight,train_data(:,n),node_num,index_active);  
  39.           
  40.         %去激活,确保数据结点1对1映射  
  41.         index_active(1,j_min)=0;  
  42.           
  43.         %为后续绘图部分服务  
  44.         index_plot(1,n)=j_min;  
  45.         [x,y]=line_to_array(j_min);  
  46.         fprintf('坐标[%d,%d]处为字符%s \n',x,y,train_label(1,n));  
  47.           
  48.         % 学习部分网络权值调整  
  49.         st=num2str(t-1);  
  50.         switch   st  
  51.             case '3'  
  52.                 [weight]=SOM_neighb3(weight,train_data(:,n),j_min,delta,alpha);  
  53.             case '2'  
  54.                 [weight]=SOM_neighb2(weight,train_data(:,n),j_min,delta,alpha);  
  55.             case '1'  
  56.                 [weight]=SOM_neighb1(weight,train_data(:,n),j_min,delta,alpha);  
  57.             otherwise  
  58.                 [weight]=SOM_neighb0(weight,train_data(:,n),j_min,alpha);  
  59.         end  
  60.           
  61.     end  
  62. end  
  63. %% 绘制结点分布图像  
  64. figure(1);  
  65. for n=1:data_num  
  66.     [x,y]=line_to_array(index_plot(1,n));  
  67.     axis([0,12,0,12]);  
  68.     text(x,y,'*');  
  69.     text(x+0.2,y+0.2,train_label(1,n));  
  70.     hold on;  
  71. end  
SOM_data_process.m

[html]  view plain  copy
  1. function [train_data,train_label]=SOM_data_process()  
  2. train_data=[1 0 0 0 0;  
  3.             2 0 0 0 0;  
  4.             3 0 0 0 0;  
  5.             4 0 0 0 0;  
  6.             5 0 0 0 0;  
  7.             3 1 0 0 0;  
  8.             3 2 0 0 0;  
  9.             3 3 0 0 0;  
  10.             3 4 0 0 0;  
  11.             3 5 0 0 0;  
  12.             3 3 1 0 0;  
  13.             3 3 2 0 0;  
  14.             3 3 3 0 0;  
  15.             3 3 4 0 0;  
  16.             3 3 5 0 0;  
  17.             3 3 3 1 0;  
  18.             3 3 3 2 0;  
  19.             3 3 3 3 0;  
  20.             3 3 3 4 0;  
  21.             3 3 3 5 0;  
  22.             3 3 3 3 1;  
  23.             3 3 3 3 2;  
  24.             3 3 3 3 3;  
  25.             3 3 3 3 4;  
  26.             3 3 3 3 5;  
  27.             3 3 3 3 6];  
  28. train_label=['A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','1','2','3','4','5','6'];  
  29. train_data=train_data';  
  30. length=size(train_data,2);  
  31. for i=1:length  
  32.      train_data(:,i)=train_data(:,i)/sqrt(sum(train_data(:,i).*train_data(:,i)));  
  33. % train_data(:,i)=train_data(:,i)/max(train_data(:,i));  
  34. end  
  35. end  

SOM_compare.m

[html]  view plain  copy
  1. function [j_min]=SOM_compare(weight,train_data_active,node_num,index_active)  
  2. for j=1:node_num  
  3.     distant(j,1)=sum((weight(:,j)-train_data_active).^2);  
  4. end  
  5. [~,j_min]=min(distant);  
  6. while(index_active(1,j_min)==0)  
  7.     distant(j_min,1)=10000000;  
  8.     [~,j_min]=min(distant);  
  9. end  
  10.   
  11. end  

SOM_neighb3.m
[html]  view plain  copy
  1. function [weight]=SOM_neighb3(weight,train_data_active,j_min,delta,alpha)  
  2.   
  3. %% 权值调整幅度分布  
  4. %                          -0.2  
  5. %                           0.2  
  6. %                           0.6  
  7. %        -0.2   0.2   0.6    1    0.6   0.2   -0.2  
  8. %                           0.6  
  9. %                           0.2  
  10. %                          -0.2  
  11. % 单位距离转化比例为0.4  
  12. %% 坐标转换  
  13. [x,y]=line_to_array(j_min);  
  14. % 将1*70向量中的坐标转化为7*10矩阵中的坐标  
  15. %    1   8    ···  
  16. %    7   14   ···  
  17.   
  18. %% 权值调整过程  
  19. %结点靠上边情况  
  20. if (x<=3)  
  21.     for m=1:1:x+3  
  22.         if (y<=3)          %结点靠左边  
  23.             for n=1:1:y+3  
  24.                 distant=sqrt((x-m)^2+(y-n)^2);  
  25.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  26.             end  
  27.         elseif (y>=8)      %结点靠右边  
  28.             for n=y-3:1:10  
  29.                 distant=sqrt((x-m)^2+(y-n)^2);  
  30.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  31.             end  
  32.         else  
  33.             for n=y-3:1:y+3  
  34.                 distant=sqrt((x-m)^2+(y-n)^2);  
  35.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  36.             end  
  37.         end  
  38.     end  
  39.     %结点靠下边情况  
  40. elseif (x>=5)  
  41.     for m=x-3:1:7  
  42.          if (y<=3)          %结点靠左边  
  43.             for n=1:1:y+3  
  44.                 distant=sqrt((x-m)^2+(y-n)^2);  
  45.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  46.             end  
  47.         elseif (y>=8)      %结点靠右边  
  48.             for n=y-3:1:10  
  49.                 distant=sqrt((x-m)^2+(y-n)^2);  
  50.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  51.             end  
  52.         else  
  53.             for n=y-3:1:y+3  
  54.                 distant=sqrt((x-m)^2+(y-n)^2);  
  55.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  56.             end  
  57.         end  
  58.     end  
  59.     %结点正好在中间  
  60. else  
  61.     for m=1:7  
  62.          if (y<=3)          %结点靠左边  
  63.             for n=1:1:y+3  
  64.                 distant=sqrt((x-m)^2+(y-n)^2);  
  65.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  66.             end  
  67.         elseif (y>=8)      %结点靠右边  
  68.             for n=y-3:1:10  
  69.                 distant=sqrt((x-m)^2+(y-n)^2);  
  70.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  71.             end  
  72.         else  
  73.             for n=y-3:1:y+3  
  74.                 distant=sqrt((x-m)^2+(y-n)^2);  
  75.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  76.             end  
  77.         end  
  78.     end  
  79. end  
  80. end  

SOM_neighb2.m

[html]  view plain  copy
  1. function [weight]=SOM_neighb2(weight,train_data_active,j_min,delta,alpha)  
  2.   
  3. %% 权值调整幅度分布  
  4. %                          -0.2  
  5. %                           0.2  
  6. %                           0.6  
  7. %        -0.2   0.2   0.6    1    0.6   0.2   -0.2  
  8. %                           0.6  
  9. %                           0.2  
  10. %                          -0.2  
  11. % 单位距离转化比例为0.4  
  12. %% 坐标转换  
  13. [x,y]=line_to_array(j_min);  
  14. % 将1*70向量中的坐标转化为7*10矩阵中的坐标  
  15. %    1   8    ···  
  16. %    7   14   ···  
  17.   
  18. %% 权值调整过程  
  19. %结点靠上边情况  
  20. if (x<=2)  
  21.     for m=1:1:x+2  
  22.         if (y<=2)          %结点靠左边  
  23.             for n=1:1:y+2  
  24.                 distant=sqrt((x-m)^2+(y-n)^2);  
  25.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  26.             end  
  27.         elseif (y>=9)      %结点靠右边  
  28.             for n=y-2:1:10  
  29.                 distant=sqrt((x-m)^2+(y-n)^2);  
  30.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  31.             end  
  32.         else  
  33.             for n=y-2:1:y+2  
  34.                 distant=sqrt((x-m)^2+(y-n)^2);  
  35.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  36.             end  
  37.         end  
  38.     end  
  39.     %结点靠下边情况  
  40. elseif (x>=6)  
  41.     for m=x-2:1:7  
  42.       if (y<=2)          %结点靠左边  
  43.             for n=1:1:y+2  
  44.                 distant=sqrt((x-m)^2+(y-n)^2);  
  45.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  46.             end  
  47.         elseif (y>=9)      %结点靠右边  
  48.             for n=y-2:1:10  
  49.                 distant=sqrt((x-m)^2+(y-n)^2);  
  50.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  51.             end  
  52.         else  
  53.             for n=y-2:1:y+2  
  54.                 distant=sqrt((x-m)^2+(y-n)^2);  
  55.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  56.             end  
  57.       end  
  58.     end  
  59.     %结点正好在中间  
  60. else  
  61.     for m=x-2:1:x+2  
  62.          if (y<=2)          %结点靠左边  
  63.             for n=1:1:y+2  
  64.                 distant=sqrt((x-m)^2+(y-n)^2);  
  65.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  66.             end  
  67.         elseif (y>=9)      %结点靠右边  
  68.             for n=y-2:1:10  
  69.                 distant=sqrt((x-m)^2+(y-n)^2);  
  70.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  71.             end  
  72.         else  
  73.             for n=y-2:1:y+2  
  74.                 distant=sqrt((x-m)^2+(y-n)^2);  
  75.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  76.             end  
  77.         end  
  78.     end  
  79. end  
  80. end  

SOM_neighb1.m

[html]  view plain  copy
  1. function [weight]=SOM_neighb1(weight,train_data_active,j_min,delta,alpha)  
  2.   
  3. %% 权值调整幅度分布  
  4. %                          -0.2  
  5. %                           0.2  
  6. %                           0.6  
  7. %        -0.2   0.2   0.6    1    0.6   0.2   -0.2  
  8. %                           0.6  
  9. %                           0.2  
  10. %                          -0.2  
  11. % 单位距离转化比例为0.4  
  12. %% 坐标转换  
  13. [x,y]=line_to_array(j_min);  
  14. % 将1*70向量中的坐标转化为7*10矩阵中的坐标  
  15. %    1   8    ···  
  16. %    7   14   ···  
  17.   
  18. %% 权值调整过程  
  19. %结点靠上边情况  
  20. if (x<=1)  
  21.     for m=1:1:x+1  
  22.         if (y<=1)          %结点靠左边  
  23.             for n=1:1:y+3  
  24.                 distant=sqrt((x-m)^2+(y-n)^2);  
  25.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  26.             end  
  27.         elseif (y>=10)      %结点靠右边  
  28.             for n=y-1:1:10  
  29.                 distant=sqrt((x-m)^2+(y-n)^2);  
  30. weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  31.             end  
  32.         else  
  33.             for n=y-1:1:y+1  
  34.                 distant=sqrt((x-m)^2+(y-n)^2);  
  35.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  36.             end  
  37.         end  
  38.     end  
  39.     %结点靠下边情况  
  40. elseif (x>=7)  
  41.     for m=x-3:1:7  
  42.         if (y<=1)          %结点靠左边  
  43.             for n=1:1:y+3  
  44.                 distant=sqrt((x-m)^2+(y-n)^2);  
  45.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  46.             end  
  47.         elseif (y>=10)      %结点靠右边  
  48.             for n=y-1:1:10  
  49.                 distant=sqrt((x-m)^2+(y-n)^2);  
  50. weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  51.             end  
  52.         else  
  53.             for n=y-1:1:y+1  
  54.                 distant=sqrt((x-m)^2+(y-n)^2);  
  55.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  56.             end  
  57.         end  
  58.     end  
  59.     %结点正好在中间  
  60. else  
  61.     for m=x-1:1:x+1  
  62.         if (y<=1)          %结点靠左边  
  63.             for n=1:1:y+3  
  64.                 distant=sqrt((x-m)^2+(y-n)^2);  
  65.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  66.             end  
  67.         elseif (y>=10)      %结点靠右边  
  68.             for n=y-1:1:10  
  69.                 distant=sqrt((x-m)^2+(y-n)^2);  
  70. weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  71.             end  
  72.         else  
  73.             for n=y-1:1:y+1  
  74.                 distant=sqrt((x-m)^2+(y-n)^2);  
  75.                 weight(:,(n-1)*7+m)=weight(:,(n-1)*7+m)-alpha*exp(-distant^2/delta^2)*(weight(:,(n-1)*7+m)-train_data_active);  
  76.             end  
  77.         end  
  78.     end  
  79. end  
  80. end  

SOM_neighb0.m

[html]  view plain  copy
  1. function [weight]=SOM_neighb0(weight,train_data_active,j_min,alpha)  
  2. weight(:,j_min)=weight(:,j_min)+alpha*(weight(:,j_min)-train_data_active);  
  3. end  

line_to_array.m

[html]  view plain  copy
  1. function [x,y]=line_to_array(j_min)  
  2. % 将1*70向量中的坐标转化为7*10矩阵中的坐标  
  3. %    1   8    ···  
  4. %    7   14   ···  
  5. y=ceil(j_min/7);  
  6. x=rem(j_min,7);  
  7. end  


四、结果显示

不同初始条件的结果图


你可能感兴趣的:(神经网络)