基于BP神经网络的数据分类
BP(Back
Propagation)网络是1986年由Rumelhart和McCelland为首的科学家小组提出,是一种按误差逆传播算法训练的多层前馈网络,是目前应用最广泛的神经网络模型之一。BP网络能学习和存贮大量的输入-输出模式映射关系,而无需事前揭示描述这种映射关系的数学方程。它的学习规则是使用最速下降法,通过反向传播来不断调整网络的权值和阈值,使网络的误差平方和最小。BP神经网络模型拓扑结构包括输入层(input)、隐层(hide layer)和输出层(output
layer)。
1 传统的BP算法简述
BP算法是一种有监督式的学习算法,其主要思想是:输入学习样本,使用反向传播算法对网络的权值和偏差进行反复的调整训练,使输出的向量与期望向量尽可能地接近,当网络输出层的误差平方和小于指定的误差时训练完成,保存网络的权值和偏差。具体步骤如下:
(1)初始化,随机给定各连接权及阀值。
(2)由给定的输入输出模式对计算隐层、输出层各单元输出
(3)计算新的连接权及阀值,计算公式如下:
(4)选取下一个输入模式对返回第2步反复训练直到网络设输出误差达到要求结束训练。
传统的BP算法,实质上是把一组样本输入/输出问题转化为一个非线性优化问题,并通过负梯度下降算法,利用迭代运算求解权值问题的一种学习方法,但其收敛速度慢且容易陷入局部极小,为此提出了一种新的算法,即高斯消元法。
2 改进的BP网络算法
2.1
改进算法概述
此前有人提出:任意选定一组自由权,通过对传递函数建立线性方程组,解得待求权。本文在此基础上将给定的目标输出直接作为线性方程等式代数和来建立线性方程组,不再通过对传递函数求逆来计算神经元的净输出,简化了运算步骤。没有采用误差反馈原理,因此用此法训练出来的神经网络结果与传统算法是等效的。其基本思想是:由所给的输入、输出模式对通过作用于神经网络来建立线性方程组,运用高斯消元法解线性方程组来求得未知权值,而未采用传统BP网络的非线性函数误差反馈寻优的思想。
2.2
改进算法的具体步骤
对给定的样本模式对,随机选定一组自由权,作为输出层和隐含层之间固定权值,通过传递函数计算隐层的实际输出,再将输出层与隐层间的权值作为待求量,直接将目标输出作为等式的右边建立方程组来求解。
(1)随机给定隐层和输入层间神经元的初始权值。
(2)由给定的样本输入计算出隐层的实际输出。
(3)计算输出层与隐层间的权值。以输出层的第r个神经元为对象,由给定的输出目标值作为等式的多项式值建立方程。
(4)重复第三步就可以求出输出层m个神经元的权值,以求的输出层的权矩阵加上随机固定的隐层与输入层的权值就等于神经网络最后训练的权矩阵。
3 计算机运算实例
%% 清空环境变量
clc
clear
%% 训练数据预测数据
data=importdata('test.txt');
%从1到768间随机排序
k=rand(1,768);
[m,n]=sort(k);
%输入输出数据
input=data(:,1:8);
output =data(:,9);
%随机提取500个样本为训练样本,268个样本为预测样本
input_train=input(n(1:500),:)';
output_train=output(n(1:500),:)';
input_test=input(n(501:768),:)';
output_test=output(n(501:768),:)';
%输入数据归一化
[inputn,inputps]=mapminmax(input_train);
%% BP网络训练
% %初始化网络结构
net=newff(inputn,output_train,10);
net.trainParam.epochs=1000;
net.trainParam.lr=0.1;
net.trainParam.goal=0.0000004;
%% 网络训练
net=train(net,inputn,output_train);
%% BP网络预测
%预测数据归一化
inputn_test=mapminmax('apply',input_test,inputps);
%网络预测输出
BPoutput=sim(net,inputn_test);
%% 结果分析
%根据网络输出找出数据属于哪类
BPoutput(find(BPoutput<0.5))=0;
BPoutput(find(BPoutput>=0.5))=1;
%% 结果分析
%画出预测种类和实际种类的分类图
figure(1)
plot(BPoutput,'og')
hold on
plot(output_test,'r*');
legend('预测类别','输出类别')
title('BP网络预测分类与实际类别比对','fontsize',12)
ylabel('类别标签','fontsize',12)
xlabel('样本数目','fontsize',12)
ylim([-0.5 1.5])
%预测正确率
rightnumber=0;
for i=1:size(output_test,2)
if
BPoutput(i)==output_test(i)
rightnumber=rightnumber+1;
end
end
rightratio=rightnumber/size(output_test,2)*100;
sprintf('测试准确率=%0.2f',rightratio)