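Basic principle of the BP (back-propagation) algorithm:
The error at the output is used to estimate the error of the layer immediately preceding the output layer. That estimate is then used to estimate the error of the layer before it, and so on. Propagating backward layer by layer in this way yields error estimates for all the other layers.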
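For sigmoid units, whose derivative is $f'(net) = y(1-y)$ (derived below), the layer-by-layer estimate can be written explicitly. The notation is introduced here for illustration and mirrors the variables in SimpleAnn.backPropagation: $t_j$ is the target, $y_j$ the output of node $j$, $\delta_j$ its error, $w_{ji}$ the weight from node $j$ to node $i$ of the next layer, $\eta$ the learning rate (learningRate) and $\alpha$ the momentum coefficient (mobp).

$$
\begin{aligned}
\delta_j^{\text{out}} &= y_j(1-y_j)(t_j - y_j) \\
\delta_j^{\text{hidden}} &= y_j(1-y_j)\sum_i \delta_i\, w_{ji} \\
\Delta w_{ji}(t) &= \alpha\,\Delta w_{ji}(t-1) + \eta\,\delta_i\, y_j \\
w_{ji} &\leftarrow w_{ji} + \Delta w_{ji}(t)
\end{aligned}
$$

The hidden-layer line is exactly the "use this error to estimate the error of the previous layer" step: each hidden node collects the errors of the next layer, weighted by the edges that connect them.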
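A three-layer BP network:
Activation function: it must be differentiable everywhere (an S-shaped, i.e. sigmoid, function is usually used).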
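With the sigmoid activation function, the input-output relationship of a node in a BP network is as follows.
Input:
$$net = x_1w_1 + x_2w_2 + \dots + x_nw_n$$
Output:
$$y = f(net) = \frac{1}{1+e^{-net}}$$
Derivative of the output:
$$f'(net) = \frac{e^{-net}}{(1+e^{-net})^2} = \frac{1}{1+e^{-net}} - \frac{1}{(1+e^{-net})^2} = y(1-y)$$
When training a neural network, we should keep net, as far as possible, in the range where convergence is fast. Since $f'(net) = y(1-y)$ reaches its maximum of 0.25 at $net = 0$ and approaches 0 as $|net|$ grows, the weight updates, which are proportional to this derivative, become tiny once a node saturates.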
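The saturation effect is easy to see numerically. The Java sketch below (the class name SigmoidSlope is made up for illustration) evaluates the sigmoid and its derivative y(1 - y) at a few values of net.

```java
/** Minimal sketch: the sigmoid y = 1 / (1 + e^(-net)) and its derivative y(1 - y). */
class SigmoidSlope {
    static double sigmoid(double net) {
        return 1.0 / (1.0 + Math.exp(-net));
    }

    /** The derivative written in terms of the output: f'(net) = y(1 - y). */
    static double derivative(double net) {
        double y = sigmoid(net);
        return y * (1 - y);
    }

    public static void main(String[] args) {
        double[] nets = { -10, -4, -1, 0, 1, 4, 10 };
        for (double net : nets) {
            System.out.printf("net = %6.1f   y = %.5f   f'(net) = %.6f%n",
                    net, sigmoid(net), derivative(net));
        }
    }
}
```

The derivative peaks at 0.25 when net = 0 and drops to roughly 4.5e-5 at |net| = 10, so a saturated node barely changes its incoming weights. Normalizing the inputs and starting from small weights are common ways to keep net near 0.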
(Today's program has been deliberately split into several classes for reusability.)
package xjx;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.Instances;
public abstract class GeneralAnn {
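//Dataset
Instances dataset;
//Number of layers, counted by nodes rather than by edges
int numLayers;
//Number of nodes in each layer. For example, [3, 4, 6, 2] means 3 input nodes (conditional attributes), 2 hidden layers with 4 and 6 nodes respectively, and 2 class values (binary classification).
int[] layerNumNodes;
//Momentum coefficient
public double mobp;
//Learning rate
public double learningRate;
//For random number generation
Random random = new Random();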
/**
********************
* The first constructor.
*
* @param paraFilename
* The arff filename.
* @param paraLayerNumNodes
* The number of nodes in each layer (may differ).
* @param paraLearningRate
* The learning rate.
* @param paraMobp
* The momentum coefficient.
********************
*/
public GeneralAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
double paraMobp) {
// Step 1. Read the data.
try {
FileReader tempReader = new FileReader(paraFilename);
dataset = new Instances(tempReader);
// The last attribute is the decision class.
dataset.setClassIndex(dataset.numAttributes() - 1);
tempReader.close();
} catch (Exception ee) {
System.out.println("Error occurred while trying to read \'" + paraFilename
+ "\' in GeneralAnn constructor.\r\n" + ee);
System.exit(0);
}
// Step 2. Accept the parameters.
layerNumNodes = paraLayerNumNodes;
numLayers = layerNumNodes.length;// Number of layers
// Adjust the input and output layer sizes to match the data.
layerNumNodes[0] = dataset.numAttributes() - 1;
layerNumNodes[numLayers - 1] = dataset.numClasses();
learningRate = paraLearningRate;
mobp = paraMobp;
}
/**
********************
* Forward prediction.
*
* @param paraInput
* The input data of one instance.
* @return The data at the output end.
********************
*/
public abstract double[] forward(double[] paraInput);
/**
********************
* Back-propagation.
*
* @param paraTarget
* For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
*
********************
*/
public abstract void backPropagation(double[] paraTarget);
/**
********************
* Train with the dataset.
********************
*/
public void train() {
double[] tempInput = new double[dataset.numAttributes() - 1];
double[] tempTarget = new double[dataset.numClasses()];
for (int i = 0; i < dataset.numInstances(); i++) {
// Fill the input data.
for (int j = 0; j < tempInput.length; j++) {
tempInput[j] = dataset.instance(i).value(j);
}
// Fill the class label (one-hot encoding).
Arrays.fill(tempTarget, 0);
tempTarget[(int) dataset.instance(i).classValue()] = 1;
// Train with this instance.
forward(tempInput);
backPropagation(tempTarget);
}
}
/**
********************
* Get the index of the maximal value in the array.
*
* @return The index.
********************
*/
public static int argmax(double[] paraArray) {
int resultIndex = -1;
double tempMax = -1e10;
for (int i = 0; i < paraArray.length; i++) {
if (tempMax < paraArray[i]) {
tempMax = paraArray[i];
resultIndex = i;
}
}
return resultIndex;
}
/**
********************
* Test with the dataset.
*
* @return The accuracy.
********************
*/
public double test() {
double[] tempInput = new double[dataset.numAttributes() - 1];
double tempNumCorrect = 0;
double[] tempPrediction;
int tempPredictedClass = -1;
for (int i = 0; i < dataset.numInstances(); i++) {
// Fill the input data.
for (int j = 0; j < tempInput.length; j++) {
tempInput[j] = dataset.instance(i).value(j);
}
// Predict this instance.
tempPrediction = forward(tempInput);
//System.out.println("prediction: " + Arrays.toString(tempPrediction));
tempPredictedClass = argmax(tempPrediction);
if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
tempNumCorrect++;
}
}
System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());
return tempNumCorrect / dataset.numInstances();
}
}
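1. layerNumNodes describes the basic structure of the network. For example, [3, 4, 6, 2] means:
a) There are 3 input ports, i.e. the data has 3 conditional attributes. If this does not match the actual data, the code corrects it automatically (see the line layerNumNodes[0] = dataset.numAttributes() - 1 in the GeneralAnn constructor).
b) There are 2 output ports, i.e. the data has 2 decision classes. If this does not match the actual data, the code corrects it automatically (see the line layerNumNodes[numLayers - 1] = dataset.numClasses() in the GeneralAnn constructor). For classification, the predicted class is the port with the largest output value.
c) There are two hidden layers, with 4 and 6 nodes respectively.
2. layerNodeValues holds the value of each network node. In the example above the network has 4 layers of nodes, so layerNodeValues.length is 4. There are 3 + 4 + 6 + 2 = 15 nodes in total, i.e. layerNodeValues[0].length = 3, layerNodeValues[1].length = 4, layerNodeValues[2].length = 6 and layerNodeValues[3].length = 2. Java supports such jagged matrices (rows of different lengths) because a two-dimensional array is treated as a one-dimensional array of one-dimensional arrays.
3. layerNodeErrors holds the error at each network node. Its size is the same as that of layerNodeValues.
4. edgeWeights holds the weight of each edge. Since the edges between two layers form a many-to-many relationship (a two-dimensional array), the edges of several layers form a three-dimensional array. For example, edge layer 0 in the example above has (3 + 1) × 4 = 16 edges, where the +1 accounts for the offset (bias). There are 4 - 1 = 3 edge layers in total, i.e. one fewer than the number of node layers. This is a very easy place to make mistakes when writing the program.
5. edgeWeightsDelta has the same size as edgeWeights and assists in adjusting it.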
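To double-check the counts above, here is a small Java sketch. The class LayerSizes and its methods are hypothetical helpers, not part of the original code; the sketch counts nodes and edges for a layer specification and allocates edgeWeights in the same jagged way SimpleAnn does.

```java
/** Sketch: sizes of the jagged arrays for a layer specification such as [3, 4, 6, 2]. */
class LayerSizes {
    /** Total number of nodes over all layers. */
    static int countNodes(int[] layerNumNodes) {
        int total = 0;
        for (int n : layerNumNodes) {
            total += n;
        }
        return total;
    }

    /** Total number of edges; edge layer l has (n_l + 1) * n_(l+1) weights, the +1 being the offset. */
    static int countEdges(int[] layerNumNodes) {
        int total = 0;
        for (int l = 0; l + 1 < layerNumNodes.length; l++) {
            total += (layerNumNodes[l] + 1) * layerNumNodes[l + 1];
        }
        return total;
    }

    public static void main(String[] args) {
        int[] spec = { 3, 4, 6, 2 };
        // One edge layer fewer than node layers; rows of different lengths (jagged array).
        double[][][] edgeWeights = new double[spec.length - 1][][];
        for (int l = 0; l < spec.length - 1; l++) {
            edgeWeights[l] = new double[spec[l] + 1][spec[l + 1]];
        }
        System.out.println("nodes = " + countNodes(spec));         // 15
        System.out.println("edge layers = " + edgeWeights.length); // 3
        System.out.println("edges = " + countEdges(spec));         // 16 + 30 + 14 = 60
    }
}
```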
Below is the core code.
package xjx;
public class SimpleAnn extends GeneralAnn{
/**
* The value of each node, changed during forward propagation. The first dimension is the layer and the second is the node.
*/
public double[][] layerNodeValues;
/**
* The error at each node, changed during back-propagation. The first dimension is the layer and the second is the node.
*/
public double[][] layerNodeErrors;
/**
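* The edge weights. The first dimension is the layer, the second is the node index in that layer, and the third is the node index in the next layer.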
*/
public double[][][] edgeWeights;
/**
* The changes of the edge weights. Its size is the same as that of edgeWeights.
*/
public double[][][] edgeWeightsDelta;
/**
********************
* The first constructor.
*
* @param paraFilename
* The arff filename.
* @param paraLayerNumNodes
* The number of nodes in each layer (may differ).
* @param paraLearningRate
* The learning rate.
* @param paraMobp
* The momentum coefficient.
********************
*/
public SimpleAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);
// Step 1. Initialize across layers.
layerNodeValues = new double[numLayers][];
layerNodeErrors = new double[numLayers][];
edgeWeights = new double[numLayers - 1][][];
edgeWeightsDelta = new double[numLayers - 1][][];
// Step 2. Initialize within each layer.
for (int l = 0; l < numLayers; l++) {
layerNodeValues[l] = new double[layerNumNodes[l]];
layerNodeErrors[l] = new double[layerNumNodes[l]];
// There is one fewer edge layer, because each edge spans two node layers.
if (l + 1 == numLayers) {
break;
}
// Of the layerNumNodes[l] + 1 rows, the last one is reserved for the offset.
edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
for (int j = 0; j < layerNumNodes[l] + 1; j++) {
for (int i = 0; i < layerNumNodes[l + 1]; i++) {
// Initialize the weights randomly in [0, 1).
edgeWeights[l][j][i] = random.nextDouble();
}
}
}
}
/**
********************
* Forward prediction.
*
* @param paraInput
* The input data of one instance.
* @return The data at the output end.
********************
*/
public double[] forward(double[] paraInput) {
//Initialize the input layer.
for (int i = 0; i < layerNodeValues[0].length; i++) {
layerNodeValues[0][i] = paraInput[i];
}
// Compute the node values layer by layer.
double z;
for (int l = 1; l < numLayers; l++) {
for (int j = 0; j < layerNodeValues[l].length; j++) {
// Start from the offset weight; the offset input is always +1.
z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
// Weighted sum over all incoming edges of this node.
for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
}
layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
}
}
return layerNodeValues[numLayers - 1];
}
/**
********************
* Back-propagate and change the edge weights.
*
* @param paraTarget
* For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
********************
*/
public void backPropagation(double[] paraTarget) {
// Step 1. Initialize the output layer errors.
int l = numLayers - 1;
for (int j = 0; j < layerNodeErrors[l].length; j++) {
layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j])
* (paraTarget[j] - layerNodeValues[l][j]);
}
// Step 2. Back-propagate until l = 0.
while (l > 0) {
l--;
// Layer l.
for (int j = 0; j < layerNumNodes[l]; j++) {
double z = 0.0;
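// For each node of the next layer. Errors of the input layer (l = 0) are never used, so they are not accumulated.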
for (int i = 0; i < layerNumNodes[l + 1]; i++) {
if (l > 0) {
z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
}
// Weight adjustment.
edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i]
+ learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];
if (j == layerNumNodes[l] - 1) {
// Weight adjustment for the offset part.
edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
+ learningRate * layerNodeErrors[l + 1][i];
edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
}
}
//Record the error.
layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
}
}
}
/**
********************
* Test the algorithm.
********************
*/
public static void main(String[] args) {
int[] tempLayerNodes = { 4, 8, 8, 3 };
SimpleAnn tempNetwork = new SimpleAnn("D:/data/iris.arff", tempLayerNodes, 0.01, 0.6);
for (int round = 0; round < 5000; round++) {
tempNetwork.train();
}
double tempAccuray = tempNetwork.test();
System.out.println("The accuracy is: " + tempAccuray);
}
}
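1. Forward is the process of using the current network to make a prediction for one piece of data.
2. BackPropagation is the process of adjusting the network weights according to the error.
3. Training requires both the forward and the backward pass; testing requires only the forward pass.
4. Only the sigmoid activation function is implemented here. The derivative used in back-propagation corresponds to the activation function used in forward propagation, so to change the activation function, both places must be changed together.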
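Point 4 can be checked mechanically: a derivative that does not belong to the activation disagrees with a numerical derivative. The sketch below (class and field names are hypothetical) compares y(1 - y), and a deliberately mismatched tanh formula 1 - y², against a central finite difference of the sigmoid.

```java
import java.util.function.DoubleUnaryOperator;

/** Sketch: checking that an analytic derivative matches its activation function. */
class DerivativeCheck {
    static final DoubleUnaryOperator SIGMOID = x -> 1 / (1 + Math.exp(-x));

    /** Correct pair for the sigmoid: f'(x) = y(1 - y) with y = f(x). */
    static final DoubleUnaryOperator SIGMOID_PRIME = x -> {
        double y = SIGMOID.applyAsDouble(x);
        return y * (1 - y);
    };

    /** Mismatched pair: the tanh derivative 1 - y^2 applied to a sigmoid output. */
    static final DoubleUnaryOperator WRONG_PRIME = x -> {
        double y = SIGMOID.applyAsDouble(x);
        return 1 - y * y;
    };

    /** Central finite difference (f(x + h) - f(x - h)) / (2h). */
    static double numericDerivative(DoubleUnaryOperator f, double x) {
        double h = 1e-5;
        return (f.applyAsDouble(x + h) - f.applyAsDouble(x - h)) / (2 * h);
    }

    /** Largest gap between the analytic and the numeric derivative on [-5, 5]. */
    static double maxGap(DoubleUnaryOperator f, DoubleUnaryOperator fPrime) {
        double worst = 0;
        for (double x = -5; x <= 5; x += 0.25) {
            worst = Math.max(worst, Math.abs(fPrime.applyAsDouble(x) - numericDerivative(f, x)));
        }
        return worst;
    }

    public static void main(String[] args) {
        System.out.println("sigmoid with y(1 - y): max gap = " + maxGap(SIGMOID, SIGMOID_PRIME));
        System.out.println("sigmoid with 1 - y^2:  max gap = " + maxGap(SIGMOID, WRONG_PRIME));
    }
}
```

The matching pair agrees to within rounding error, while the mismatched pair is off by more than 0.5 at net = 0, which is the kind of silent error that occurs when only one of the two places is changed.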
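The activation function is the core of a neural network.
1. Activation and its derivative come as a pair: the former is used in forward, the latter in back-propagation.
2. There are many activation functions, and their design follows certain criteria, such as being piecewise differentiable.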
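The ReLU function, also called the rectified linear unit, is a piecewise linear function that alleviates the vanishing-gradient problem of the sigmoid and tanh functions. The ReLU function is defined as:
$$g(z) = \begin{cases} z, & \text{if } z > 0 \\ 0, & \text{if } z \le 0 \end{cases}$$
Its derivative is:
$$g'(z) = \begin{cases} 1, & \text{if } z > 0 \\ 0, & \text{if } z < 0 \end{cases}$$
ReLU is not differentiable at z = 0; in practice a value is simply chosen there (the Activator code below uses 1 for z ≥ 0).
Advantages of ReLU:
(1) When the input is positive (for most of the input space of z), there is no vanishing-gradient problem.
(2) It is much faster to compute. ReLU involves only linear relations, so both forward and backward propagation are much faster than with sigmoid and tanh, which need to compute exponentials.
Disadvantages of ReLU:
(1) When the input is negative, the gradient is 0, so the gradient vanishes there and the node stops learning.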
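The vanishing-gradient argument can be made concrete. During back-propagation each layer multiplies the error by the derivative of its activation, so the derivatives compound. The sketch below (hypothetical class GradientProduct) simplifies heavily: it ignores the weights, i.e. treats every weight as 1, and evaluates every layer at the same net.

```java
/** Sketch: how activation derivatives compound over layers (weights ignored). */
class GradientProduct {
    /** Product of sigmoid derivatives y(1 - y) over the given number of layers. */
    static double sigmoidChain(double net, int layers) {
        double y = 1 / (1 + Math.exp(-net));
        return Math.pow(y * (1 - y), layers);
    }

    /** Product of ReLU derivatives (1 if net > 0, else 0) over the given number of layers. */
    static double reluChain(double net, int layers) {
        double slope = net > 0 ? 1.0 : 0.0;
        return Math.pow(slope, layers);
    }

    public static void main(String[] args) {
        for (int layers : new int[] { 1, 5, 10 }) {
            System.out.printf("%2d layers: sigmoid(net = 0) -> %.3e, ReLU(net = 1) -> %.1f, ReLU(net = -1) -> %.1f%n",
                    layers, sigmoidChain(0, layers), reluChain(1, layers), reluChain(-1, layers));
        }
    }
}
```

Even in the sigmoid's best case (net = 0), ten layers shrink the error by 0.25^10 ≈ 9.5e-7, while ReLU passes it through unchanged for positive inputs and blocks it completely for negative ones, matching advantage (1) and disadvantage (1) above.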
Code:
package xjx;
public class Activator {
/**
* Arc tan.
*/
public final char ARC_TAN = 'a';
/**
* Elu.
*/
public final char ELU = 'e';
/**
* Gelu.
*/
public final char GELU = 'g';
/**
* Hard logistic.
*/
public final char HARD_LOGISTIC = 'h';
/**
* Identity.
*/
public final char IDENTITY = 'i';
/**
* Leaky relu, also known as parametric relu.
*/
public final char LEAKY_RELU = 'l';
/**
* Relu.
*/
public final char RELU = 'r';
/**
* Soft sign.
*/
public final char SOFT_SIGN = 'o';
/**
* Sigmoid.
*/
public final char SIGMOID = 's';
/**
* Tanh.
*/
public final char TANH = 't';
/**
* Soft plus.
*/
public final char SOFT_PLUS = 'u';
/**
* Swish.
*/
public final char SWISH = 'w';
/**
* The activator.
*/
private char activator;
/**
* Alpha for elu and leaky relu.
*/
double alpha;
/**
* Beta. Not used by the activators implemented below.
*/
double beta;
/**
* Gamma. Not used by the activators implemented below.
*/
double gamma;
/**
*********************
* The first constructor.
*
* @param paraActivator
* The activator.
*********************
*/
public Activator(char paraActivator) {
activator = paraActivator;
}
/**
*********************
* Setter.
*********************
*/
public void setActivator(char paraActivator) {
activator = paraActivator;
}
/**
*********************
* Getter.
*********************
*/
public char getActivator() {
return activator;
}
/**
*********************
* Setter.
*********************
*/
void setAlpha(double paraAlpha) {
alpha = paraAlpha;
}
/**
*********************
* Setter.
*********************
*/
void setBeta(double paraBeta) {
beta = paraBeta;
}
/**
*********************
* Setter.
*********************
*/
void setGamma(double paraGamma) {
gamma = paraGamma;
}
/**
*********************
* Activate according to the activation function.
*********************
*/
public double activate(double paraValue) {
double resultValue = 0;
switch (activator) {
case ARC_TAN:
resultValue = Math.atan(paraValue);
break;
case ELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = alpha * (Math.exp(paraValue) - 1);
}
break;
// case GELU:
// resultValue = ?;
// break;
// case HARD_LOGISTIC:
// resultValue = ?;
// break;
case IDENTITY:
resultValue = paraValue;
break;
case LEAKY_RELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = alpha * paraValue;
}
break;
case SOFT_SIGN:
if (paraValue >= 0) {
resultValue = paraValue / (1 + paraValue);
} else {
resultValue = paraValue / (1 - paraValue);
}
break;
case SOFT_PLUS:
resultValue = Math.log(1 + Math.exp(paraValue));
break;
case RELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = 0;
}
break;
case SIGMOID:
resultValue = 1 / (1 + Math.exp(-paraValue));
break;
case TANH:
resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
break;
// case SWISH:
// resultValue = ?;
// break;
default:
System.out.println("Unsupported activator: " + activator);
System.exit(0);
}
return resultValue;
}
/**
*********************
* Compute the derivative according to the activation function.
*
* @param paraValue
* The original value x.
* @param paraActivatedValue
* f(x).
*********************
*/
public double derive(double paraValue, double paraActivatedValue) {
double resultValue = 0;
switch (activator) {
case ARC_TAN:
resultValue = 1 / (paraValue * paraValue + 1);
break;
case ELU:
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
} // Of if
break;
// case GELU:
// resultValue = ?;
// break;
// case HARD_LOGISTIC:
// resultValue = ?;
// break;
case IDENTITY:
resultValue = 1;
break;
case LEAKY_RELU:
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = alpha;
}
break;
case SOFT_SIGN:
if (paraValue >= 0) {
resultValue = 1 / (1 + paraValue) / (1 + paraValue);
} else {
resultValue = 1 / (1 - paraValue) / (1 - paraValue);
}
break;
case SOFT_PLUS:
resultValue = 1 / (1 + Math.exp(-paraValue));
break;
case RELU: // Updated
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = 0;
}
break;
case SIGMOID: // Updated
resultValue = paraActivatedValue * (1 - paraActivatedValue);
break;
case TANH: // Updated
resultValue = 1 - paraActivatedValue * paraActivatedValue;
break;
// case SWISH:
// resultValue = ?;
// break;
default:
System.out.println("Unsupported activator: " + activator);
System.exit(0);
}
return resultValue;
}
/**
*********************
* Overrides the method declared in Object.
*********************
*/
public String toString() {
String resultString = "Activator with function '" + activator + "'";
resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;
return resultString;
}
/**
********************
* Test.
********************
*/
public static void main(String[] args) {
Activator tempActivator = new Activator('s');
double tempValue = 0.6;
double tempNewValue;
tempNewValue = tempActivator.activate(tempValue);
System.out.println("After activation: " + tempNewValue);
tempNewValue = tempActivator.derive(tempValue, tempNewValue);
System.out.println("After derive: " + tempNewValue);
}
}
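1. Only a single ANN layer is implemented.
2. The forward pass computes the output; the backward pass computes the errors and adjusts the weights.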
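The update rule used by the layer, Δw(t) = mobp·Δw(t - 1) + learningRate·δ·x, can be tried out on a single sigmoid node. The sketch below is a hypothetical stand-alone class (MomentumNeuron), not the AnnLayer code. It treats the offset as an extra weight whose input is fixed at 1, and the initial weights, learning rate and momentum are arbitrary values chosen for illustration.

```java
/** Sketch: one sigmoid output node with an offset, trained with a momentum update. */
class MomentumNeuron {
    final double[] weights;      // weights[0..n-1] for the inputs, weights[n] for the offset
    final double[] deltaWeights; // previous update of every weight (momentum term)
    final double learningRate;
    final double mobp;

    MomentumNeuron(double[] initialWeights, double learningRate, double mobp) {
        this.weights = initialWeights.clone();
        this.deltaWeights = new double[initialWeights.length];
        this.learningRate = learningRate;
        this.mobp = mobp;
    }

    double forward(double[] input) {
        int n = input.length;
        double net = weights[n]; // start from the offset
        for (int i = 0; i < n; i++) {
            net += weights[i] * input[i];
        }
        return 1 / (1 + Math.exp(-net));
    }

    /** One training step; returns the squared error measured before the update. */
    double trainStep(double[] input, double target) {
        double y = forward(input);
        double delta = y * (1 - y) * (target - y); // sigmoid derivative times output error
        int n = input.length;
        for (int i = 0; i <= n; i++) {
            double x = (i < n) ? input[i] : 1.0; // the offset acts like an input fixed at 1
            deltaWeights[i] = mobp * deltaWeights[i] + learningRate * delta * x;
            weights[i] += deltaWeights[i];
        }
        return (target - y) * (target - y);
    }

    public static void main(String[] args) {
        MomentumNeuron neuron = new MomentumNeuron(new double[] { 0.5, 0.5, 0.5 }, 0.5, 0.6);
        double[] input = { 1, 4 };
        for (int step = 1; step <= 50; step++) {
            double error = neuron.trainStep(input, 0.0);
            if (step % 10 == 1) {
                System.out.printf("step %2d: squared error before update = %.6f%n", step, error);
            }
        }
    }
}
```

Keeping the offset in the same array as the other weights means forward and backward automatically read and write the same value, which is one way to avoid the two drifting apart.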
Output:
Activator: Activator with function 's'
alpha = 0.0, beta = 0.0, gamma = 0.0
weights = [[0.6084959193944588, 0.4221456104753831, 0.6183449276687938], [0.7704253816634953, 0.6636288072285302, 0.8794802183018241], [0.17489521516629425, 0.004899930192123647, 0.13474601385167118]]
Forward, the output is: [0.9794693622124504, 0.9561257044911313, 0.9862247642459836]
Back propagation, the error is: [0.03720166959927697, 0.053575467276064444]
Code:
package xjx;
import java.util.Arrays;
import java.util.Random;
public class AnnLayer {
/**
* Number of inputs.
*/
int numInput;
/**
* Number of outputs.
*/
int numOutput;
/**
* Learning rate.
*/
double learningRate;
/**
* Momentum coefficient.
*/
double mobp;
/**
* Weight matrix and its changes; the last row of weights holds the offsets.
*/
double[][] weights, deltaWeights;
double[] offset, deltaOffset, errors;
/**
* Input.
*/
double[] input;
/**
* Output (weighted sums before activation).
*/
double[] output;
/**
* Output after activation.
*/
double[] activatedOutput;
/**
* The activator.
*/
Activator activator;
Random random = new Random();
/**
*********************
* The first constructor.
*
* @param paraActivator
* The activator.
*********************
*/
public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator, double paraLearningRate, double paraMobp) {
numInput = paraNumInput;
numOutput = paraNumOutput;
learningRate = paraLearningRate;
mobp = paraMobp;
weights = new double[numInput + 1][numOutput];
deltaWeights = new double[numInput + 1][numOutput];
for (int i = 0; i < numInput + 1; i++) {
for (int j = 0; j < numOutput; j++) {
weights[i][j] = random.nextDouble();
}
}
offset = new double[numOutput];
deltaOffset = new double[numOutput];
errors = new double[numInput];
input = new double[numInput];
output = new double[numOutput];
activatedOutput = new double[numOutput];
activator = new Activator(paraActivator);
}
/**
********************
* Set the parameters of the activator.
*
* @param paraAlpha
* Alpha. Only valid for some activator types.
* @param paraBeta
* Beta.
* @param paraGamma
* Gamma.
********************
*/
public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
activator.setAlpha(paraAlpha);
activator.setBeta(paraBeta);
activator.setGamma(paraGamma);
}
/**
********************
* Forward prediction.
*
* @param paraInput
* The input data of one instance.
* @return The data at the output end.
********************
*/
public double[] forward(double[] paraInput) {
//System.out.println("Ann layer forward " + Arrays.toString(paraInput));
// Copy the data.
for (int i = 0; i < numInput; i++) {
input[i] = paraInput[i];
}
// Compute the weighted sum for each output, starting from the offset in the last row.
for (int i = 0; i < numOutput; i++) {
output[i] = weights[numInput][i];
for (int j = 0; j < numInput; j++) {
output[i] += input[j] * weights[j][i];
}
activatedOutput[i] = activator.activate(output[i]);
}
return activatedOutput;
}
/**
********************
* Back-propagate and change the edge weights.
*
* @param paraErrors
* The errors at the outputs of this layer (modified in place).
* @return The errors propagated to the inputs of this layer.
********************
*/
public double[] backPropagation(double[] paraErrors) {
//Step 1. Adjust the errors by the derivative of the activation function.
for (int i = 0; i < paraErrors.length; i++) {
paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
}
//Step 2. Compute the errors of the current layer and update the weights.
for (int i = 0; i < numInput; i++) {
errors[i] = 0;
for (int j = 0; j < numOutput; j++) {
errors[i] += paraErrors[j] * weights[i][j];
deltaWeights[i][j] = mobp * deltaWeights[i][j] + learningRate * paraErrors[j] * input[i];
weights[i][j] += deltaWeights[i][j];
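if (i == numInput - 1) {
// Offset adjustment. forward() reads the offset from weights[numInput][j],
// so it is updated there; the separate offset array is never read.
deltaWeights[numInput][j] = mobp * deltaWeights[numInput][j] + learningRate * paraErrors[j];
weights[numInput][j] += deltaWeights[numInput][j];
}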
}
}
return errors;
}
/**
********************
* I am the last layer, set the errors.
*
* @param paraTarget
* For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
********************
*/
public double[] getLastLayerErrors(double[] paraTarget) {
double[] resultErrors = new double[numOutput];
for (int i = 0; i < numOutput; i++) {
resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
}
return resultErrors;
}
/**
********************
* Show me.
********************
*/
public String toString() {
String resultString = "";
resultString += "Activator: " + activator;
resultString += "\r\n weights = " + Arrays.deepToString(weights);
return resultString;
}
/**
********************
* Unit test.
********************
*/
public static void unitTest() {
AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
double[] tempInput = { 1, 4 };
System.out.println(tempLayer);
double[] tempOutput = tempLayer.forward(tempInput);
System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));
double[] tempError = tempLayer.backPropagation(tempOutput);
System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
}
/**
********************
* Test the algorithm.
********************
*/
public static void main(String[] args) {
unitTest();
}
}
Test results:
Correct: 146.0 out of 150
The accuracy is: 0.9733333333333334
FullAnn ends.
Code:
package xjx;
import java.util.Arrays;
public class FullAnn extends GeneralAnn {
/**
* The layers.
*/
AnnLayer[] layers;
/**
********************
* The first constructor.
*
* @param paraFilename
* The arff filename.
* @param paraLayerNumNodes
* The number of nodes for each layer (may be different).
* @param paraLearningRate
* Learning rate.
* @param paraMobp
* Momentum coefficient.
* @param paraActivators The activators of each layer, one character per layer (for example "sss").
********************
*/
public FullAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp, String paraActivators) {
super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);
// Initialize the layers.
layers = new AnnLayer[numLayers - 1];
for (int i = 0; i < layers.length; i++) {
layers[i] = new AnnLayer(layerNumNodes[i], layerNumNodes[i + 1], paraActivators.charAt(i), paraLearningRate, paraMobp);
}
}
/**
********************
* Forward prediction.
*
* @param paraInput
* The input data of one instance.
* @return The data at the output end.
********************
*/
public double[] forward(double[] paraInput) {
double[] resultArray = paraInput;
for(int i = 0; i < numLayers - 1; i ++) {
resultArray = layers[i].forward(resultArray);
}
return resultArray;
}
/**
********************
* Back-propagation.
*
* @param paraTarget
* For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
*
********************
*/
public void backPropagation(double[] paraTarget) {
double[] tempErrors = layers[numLayers - 2].getLastLayerErrors(paraTarget);
for (int i = numLayers - 2; i >= 0; i--) {
tempErrors = layers[i].backPropagation(tempErrors);
}
return;
}
/**
********************
* Show me.
********************
*/
public String toString() {
String resultString = "I am a full ANN with " + numLayers + " layers";
return resultString;
}
/**
********************
* Test.
********************
*/
public static void main(String[] args) {
int[] tempLayerNodes = { 4, 8, 8, 3 };
FullAnn tempNetwork = new FullAnn("D:/data/iris.arff", tempLayerNodes, 0.01, 0.6, "sss");
for (int round = 0; round < 5000; round++) {
tempNetwork.train();
}
double tempAccuray = tempNetwork.test();
System.out.println("The accuracy is: " + tempAccuray);
System.out.println("FullAnn ends.");
}
}
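First, on the Eclipse website, choose the GUI plugin (WindowBuilder) for your Eclipse version: installation link
Install it in Eclipse: Help -> Install New Software…
Restart the IDE after installation.
Then create a new project via New → Project → WindowBuilder → SWT Designer → SWT/JFace Java Project and create a package. When creating a class, choose New → Other, then WindowBuilder → Swing Designer → Application Window. Once the class has been created, click Design to edit it visually.
However, importing java.awt.event causes an error in a modular project; declaring requires java.desktop; in module-info.java fixes it.
The interface can be designed in the Design view:
Code notes:
ApplicationShutdown.java is used only to exit the graphical user interface (GUI).
Only one static instance is generated. The constructor is private, so the class cannot be instantiated with new outside the class. This is a neat little trick.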
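The same idiom in isolation: a minimal, hypothetical class (SharedCounter, not part of the original project) whose only instance is created in a static field. Unlike ApplicationShutdown, the field here is also declared final so it cannot be reassigned later; that is a design choice of this sketch.

```java
/** Sketch of the private-constructor idiom: exactly one instance, created by the class itself. */
final class SharedCounter {
    /** The only instance, created when the class is initialized. */
    public static final SharedCounter INSTANCE = new SharedCounter();

    private int count = 0;

    /** Private, so code outside this class cannot write new SharedCounter(). */
    private SharedCounter() {
    }

    int next() {
        return ++count;
    }

    public static void main(String[] args) {
        SharedCounter a = SharedCounter.INSTANCE;
        SharedCounter b = SharedCounter.INSTANCE;
        System.out.println(a == b);                    // true: both names refer to one object
        System.out.println(a.next() + " " + b.next()); // 1 2: the state is shared
    }
}
```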
package xjx;
import java.awt.event.*;
public class ApplicationShutdown implements WindowListener, ActionListener {
/**
* Only one instance.
*/
public static ApplicationShutdown applicationShutdown = new ApplicationShutdown();
/**
***************************
* This constructor is private such that the only instance is generated here.
***************************
*/
private ApplicationShutdown() {
}// Of ApplicationShutdown.
/**
***************************
* Shutdown the system
***************************
*/
public void windowClosing(WindowEvent comeInWindowEvent) {
System.exit(0);
}// Of windowClosing.
public void windowActivated(WindowEvent comeInWindowEvent) {
}// Of windowActivated.
public void windowClosed(WindowEvent comeInWindowEvent) {
}// Of windowClosed.
public void windowDeactivated(WindowEvent comeInWindowEvent) {
}// Of windowDeactivated.
public void windowDeiconified(WindowEvent comeInWindowEvent) {
}// Of windowDeiconified.
public void windowIconified(WindowEvent comeInWindowEvent) {
}// Of windowIconified.
public void windowOpened(WindowEvent comeInWindowEvent) {
}// Of windowOpened.
/**
*************************
* Shutdown the system when an action is performed, e.g. an exit button is clicked.
*************************
*/
public void actionPerformed(ActionEvent ee) {
System.exit(0);
}// Of actionPerformed.
}// Of class ApplicationShutdown
DialogCloser.java closes a dialog window rather than the whole GUI.
package xjx;
import java.awt.*;
import java.awt.event.*;
public class DialogCloser extends WindowAdapter implements ActionListener {
/**
* The dialog under control.
*/
private Dialog currentDialog;
/**
***************************
* The first constructor.
***************************
*/
public DialogCloser() {
super();
}// Of the first constructor
/**
***************************
* The second constructor.
*
* @param paraDialog
* the dialog under control
***************************
*/
public DialogCloser(Dialog paraDialog) {
currentDialog = paraDialog;
}// Of the second constructor
/**
***************************
* Close the dialog when the cross at the upper-right corner of the window is clicked.
*
* @param paraWindowEvent
* From it we can obtain which window sent the message because X
* was used.
***************************
*/
public void windowClosing(WindowEvent paraWindowEvent) {
paraWindowEvent.getWindow().dispose();
}// Of windowClosing.
/**
***************************
* Close the dialog while pushing an "OK" or "Cancel" button.
*
* @param paraEvent
* Not considered.
***************************
*/
public void actionPerformed(ActionEvent paraEvent) {
currentDialog.dispose();
}// Of actionPerformed.
}// Of class DialogCloser
ErrorDialog.java displays error messages. With a GUI, we no longer need System.out.println.
package xjx;
import java.awt.*;
public class ErrorDialog extends Dialog {
/**
* Serial uid. Not quite useful.
*/
private static final long serialVersionUID = 124535235L;
/**
* The ONLY ErrorDialog.
*/
public static ErrorDialog errorDialog = new ErrorDialog();
/**
* The text area containing the message to display.
*/
private TextArea messageTextArea;
/**
***************************
* Display an error dialog and respective error message. Like other dialogs,
* this constructor is private, such that users can use only one dialog,
* i.e., ErrorDialog.errorDialog to display message. This is helpful for
* saving space (only one dialog) since we may need "many" dialogs.
***************************
*/
private ErrorDialog() {
// This dialog is modal.
super(GUICommon.mainFrame, "Error", true);
// Prepare for the dialog.
messageTextArea = new TextArea();
Button okButton = new Button("OK");
okButton.setSize(20, 10);
okButton.addActionListener(new DialogCloser(this));
Panel okPanel = new Panel();
okPanel.setLayout(new FlowLayout());
okPanel.add(okButton);
// Add TextArea and Button
setLayout(new BorderLayout());
add(BorderLayout.CENTER, messageTextArea);
add(BorderLayout.SOUTH, okPanel);
setLocation(200, 200);
setSize(500, 200);
addWindowListener(new DialogCloser());
setVisible(false);
}// Of constructor
/**
***************************
* set message.
*
* @param paramMessage
* the new message
***************************
*/
public void setMessageAndShow(String paramMessage) {
messageTextArea.setText(paramMessage);
setVisible(true);
}// Of setMessageAndShow
}// Of class ErrorDialog
GUICommon.java stores some shared variables.
package xjx;
import java.awt.*;
import javax.swing.*;
public class GUICommon extends Object {
/**
* Only one main frame.
*/
public static Frame mainFrame = null;
/**
* Only one main pane.
*/
public static JTabbedPane mainPane = null;
/**
* For default project number.
*/
public static int currentProjectNumber = 0;
/**
* Default font.
*/
public static final Font MY_FONT = new Font("Times New Roman", Font.PLAIN, 12);
/**
* Default color
*/
public static final Color MY_COLOR = Color.lightGray;
/**
***************************
* Set the main frame. This can be done only once at the initializing stage.
*
* @param paraFrame
* the main frame of the GUI.
* @throws Exception
* If the main frame is set more than once.
***************************
*/
public static void setFrame(Frame paraFrame) throws Exception {
if (mainFrame == null) {
mainFrame = paraFrame;
} else {
throw new Exception("The main frame can be set only ONCE!");
} // Of if
}// Of setFrame
/**
***************************
* Set the main pane. This can be done only once at the initializing stage.
*
* @param paramPane
* the main pane of the GUI.
* @throws Exception
* If the main panel is set more than once.
***************************
*/
public static void setPane(JTabbedPane paramPane) throws Exception {
if (mainPane == null) {
mainPane = paramPane;
} else {
throw new Exception("The main panel can be set only ONCE!");
} // Of if
}// Of setPane
}// Of class GUICommon
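HelpDialog.java displays help information, so that clicking the Help button in the main interface shows descriptions of the relevant parameters. Its purpose is to improve the usability and maintainability of the software.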
package xjx;
import java.io.*;
import java.awt.*;
import java.awt.event.*;
public class HelpDialog extends Dialog implements ActionListener {
/**
* Serial uid. Not quite useful.
*/
private static final long serialVersionUID = 3869415040299264995L;
/**
***************************
* Display the help dialog.
*
* @param paraTitle
* the title of the dialog.
* @param paraFilename
* the help file.
***************************
*/
public HelpDialog(String paraTitle, String paraFilename) {
super(GUICommon.mainFrame, paraTitle, true);
setBackground(GUICommon.MY_COLOR);
TextArea displayArea = new TextArea("", 10, 10, TextArea.SCROLLBARS_VERTICAL_ONLY);
displayArea.setEditable(false);
String textToDisplay = "";
try {
RandomAccessFile helpFile = new RandomAccessFile(paraFilename, "r");
String tempLine = helpFile.readLine();
while (tempLine != null) {
textToDisplay = textToDisplay + tempLine + "\n";
tempLine = helpFile.readLine();
}
helpFile.close();
} catch (IOException ee) {
dispose();
ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
}
// Use this if you need to display Chinese. Consult the author for this
// method.
// textToDisplay = SimpleTools.GB2312ToUNICODE(textToDisplay);
displayArea.setText(textToDisplay);
displayArea.setFont(new Font("Times New Roman", Font.PLAIN, 14));
Button okButton = new Button("OK");
okButton.setSize(20, 10);
okButton.addActionListener(new DialogCloser(this));
Panel okPanel = new Panel();
okPanel.setLayout(new FlowLayout());
okPanel.add(okButton);
// OK Button
setLayout(new BorderLayout());
add(BorderLayout.CENTER, displayArea);
add(BorderLayout.SOUTH, okPanel);
setLocation(120, 70);
setSize(500, 400);
addWindowListener(new DialogCloser());
setVisible(false);
}// Of constructor
/**
*************************
* Simply set it visible.
*************************
*/
public void actionPerformed(ActionEvent ee) {
setVisible(true);
}// Of actionPerformed.
}// Of class HelpDialog
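DoubleField.java accepts real-valued input and reports an error if the text cannot be interpreted as a real number. This nips users' trivial mistakes in the bud.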
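The check at the heart of DoubleField can be tried without a GUI. The sketch below (hypothetical class DoubleInput) returns an empty OptionalDouble where DoubleField would show ErrorDialog and request the focus again. It catches NumberFormatException specifically, which is what Double.parseDouble throws for malformed text (a null string would throw NullPointerException instead).

```java
import java.util.OptionalDouble;

/** Sketch: the parsing check performed by DoubleField, separated from the AWT widget. */
class DoubleInput {
    /** Returns the parsed value, or an empty OptionalDouble if the text is not a double. */
    static OptionalDouble parse(String text) {
        try {
            // Double.parseDouble ignores leading and trailing whitespace.
            return OptionalDouble.of(Double.parseDouble(text));
        } catch (NumberFormatException e) {
            return OptionalDouble.empty();
        }
    }

    public static void main(String[] args) {
        for (String text : new String[] { "5.13", " 0.01 ", "1e-3", "abc", "" }) {
            OptionalDouble value = parse(text);
            System.out.println("\"" + text + "\" -> "
                    + (value.isPresent() ? String.valueOf(value.getAsDouble()) : "not a double"));
        }
    }
}
```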
package xjx;
import java.awt.*;
import java.awt.event.*;
public class DoubleField extends TextField implements FocusListener {
/**
* Serial uid. Not quite useful.
*/
private static final long serialVersionUID = 363634723L;
/**
* The value
*/
protected double doubleValue;
/**
***************************
* Give it default values.
***************************
*/
public DoubleField() {
this("5.13", 10);
}// Of the first constructor
/**
***************************
* Only specify the content.
*
* @param paraString
* The content of the field.
***************************
*/
public DoubleField(String paraString) {
this(paraString, 10);
}// Of the second constructor
/**
***************************
* Only specify the width.
*
* @param paraWidth
* The width of the field.
***************************
*/
public DoubleField(int paraWidth) {
this("5.13", paraWidth);
}// Of the third constructor
/**
***************************
* Specify the content and the width.
*
* @param paraString
* The content of the field.
* @param paraWidth
* The width of the field.
***************************
*/
public DoubleField(String paraString, int paraWidth) {
super(paraString, paraWidth);
addFocusListener(this);
}// Of the fourth constructor
/**
**********************************
* Implement FocusListener.
*
* @param paraEvent
* The event is unimportant.
**********************************
*/
public void focusGained(FocusEvent paraEvent) {
}// Of focusGained
/**
**********************************
* Implement FocusListener.
*
* @param paraEvent
* The event is unimportant.
**********************************
*/
public void focusLost(FocusEvent paraEvent) {
try {
doubleValue = Double.parseDouble(getText());
} catch (Exception ee) {
ErrorDialog.errorDialog
.setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
requestFocus();
} // Of try
}// Of focusLost
/**
**********************************
* Get the double value.
*
* @return the double value.
**********************************
*/
public double getValue() {
try {
doubleValue = Double.parseDouble(getText());
} catch (Exception ee) {
ErrorDialog.errorDialog
.setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
requestFocus();
} // Of try
return doubleValue;
}// Of getValue
}// Of class DoubleField
IntegeField.java works the same way for integers.
package xjx_GUI;
import java.awt.*;
import java.awt.event.*;
public class IntegeField extends TextField implements FocusListener {
/**
* Serial uid. Not quite useful.
*/
private static final long serialVersionUID = -2462338973265150779L;
/**
***************************
* Use the default content "513".
***************************
*/
public IntegeField() {
this("513");
}// Of constructor
/**
***************************
* Specify the content and the width.
*
* @param paraString
* The default value of the content.
* @param paraWidth
* The width of the field.
***************************
*/
public IntegeField(String paraString, int paraWidth) {
super(paraString, paraWidth);
addFocusListener(this);
}// Of constructor
/**
***************************
* Only specify the content.
*
* @param paraString
* The given default string.
***************************
*/
public IntegeField(String paraString) {
super(paraString);
addFocusListener(this);
}// Of constructor
/**
***************************
* Only specify the width.
*
* @param paraWidth
* The width of the field.
***************************
*/
public IntegeField(int paraWidth) {
super(paraWidth);
setText("513");
addFocusListener(this);
}// Of constructor
/**
**********************************
* Implement FocusListener.
*
* @param paraEvent
* The event is unimportant.
**********************************
*/
public void focusGained(FocusEvent paraEvent) {
}// Of focusGained
/**
**********************************
* Implement FocusListener.
*
* @param paraEvent
* The event is unimportant.
**********************************
*/
public void focusLost(FocusEvent paraEvent) {
try {
Integer.parseInt(getText());
// System.out.println(tempInt);
} catch (Exception ee) {
ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
+ "\" Not an integer. Please check.");
requestFocus();
}
}// Of focusLost
/**
**********************************
* Get the int value. Show error message if the content is not an int.
*
* @return the int value.
**********************************
*/
public int getValue() {
int tempInt = 0;
try {
tempInt = Integer.parseInt(getText());
} catch (Exception ee) {
ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
+ "\" Not an int. Please check.");
requestFocus();
}
return tempInt;
}// Of getValue
}// Of class IntegerField
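DoubleField and IntegeField tie their validation to AWT focus events, which makes the parsing rule hard to test without a display. The following is a minimal sketch that factors that rule into a plain helper. FieldParsers and its methods are hypothetical names, and the parsing calls are the same Double.parseDouble and Integer.parseInt the fields already use.

```java
import java.util.OptionalDouble;
import java.util.OptionalInt;

// Hypothetical helper: the parsing rules of DoubleField and IntegeField
// without any AWT code, so they can be checked without a display.
final class FieldParsers {
    private FieldParsers() {
    }

    /** Same rule as DoubleField; an empty result means "not a double". */
    static OptionalDouble tryParseDouble(String text) {
        // Double.parseDouble(null) throws NullPointerException, so guard it.
        if (text == null) {
            return OptionalDouble.empty();
        }
        try {
            return OptionalDouble.of(Double.parseDouble(text));
        } catch (NumberFormatException e) {
            return OptionalDouble.empty();
        }
    }

    /** Same rule as IntegeField; an empty result means "not an int". */
    static OptionalInt tryParseInt(String text) {
        try {
            // Integer.parseInt(null) already throws NumberFormatException.
            return OptionalInt.of(Integer.parseInt(text));
        } catch (NumberFormatException e) {
            return OptionalInt.empty();
        }
    }

    public static void main(String[] args) {
        System.out.println(tryParseDouble("0.01")); // OptionalDouble[0.01]
        System.out.println(tryParseInt("5000"));    // OptionalInt[5000]
        System.out.println(tryParseInt("5.0"));     // OptionalInt.empty
    }
}
```

With the helper in place, focusLost and getValue would only show ErrorDialog and call requestFocus() when the result is empty.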
FilenameField.java, in turn, relies on the system-provided FileDialog.
package xjx_GUI;
import java.io.*;
import java.awt.*;
import java.awt.event.*;
public class FilenameField extends TextField implements ActionListener,
FocusListener {
/**
* Serial uid. Not quite useful.
*/
private static final long serialVersionUID = 4572287941606065298L;
/**
***************************
* No special initialization.
***************************
*/
public FilenameField() {
super();
setText("");
addFocusListener(this);
}// Of constructor
/**
***************************
* No special initialization.
*
* @param paraWidth
* The width of the field.
***************************
*/
public FilenameField(int paraWidth) {
super(paraWidth);
setText("");
addFocusListener(this);
}// Of constructor
/**
***************************
* No special initialization.
*
* @param paraWidth
* The width of the field.
* @param paraText
* The given initial text
***************************
*/
public FilenameField(int paraWidth, String paraText) {
super(paraWidth);
setText(paraText);
addFocusListener(this);
}// Of constructor
/**
***************************
* No special initialization.
*
* @param paraWidth
* The width of the field.
* @param paraText
* The given initial text
***************************
*/
public FilenameField(String paraText, int paraWidth) {
super(paraWidth);
setText(paraText);
addFocusListener(this);
}// Of constructor
/**
**********************************
* Avoid setting null or empty string.
*
* @param paraText
* The given text.
**********************************
*/
public void setText(String paraText) {
if (paraText == null || paraText.trim().equals("")) {
super.setText("unspecified");
} else {
super.setText(paraText.replace('\\', '/'));
}//Of if
}// Of setText
/**
**********************************
* Implement ActionListener.
*
* @param paraEvent
* The event is unimportant.
**********************************
*/
public void actionPerformed(ActionEvent paraEvent) {
FileDialog tempDialog = new FileDialog(GUICommon.mainFrame,
"Select a file");
tempDialog.setVisible(true);
// The user cancelled the dialog.
if (tempDialog.getDirectory() == null || tempDialog.getFile() == null) {
setText("");
return;
}//Of if
String directoryName = tempDialog.getDirectory();
String tempFilename = directoryName + tempDialog.getFile();
//System.out.println("tempFilename = " + tempFilename);
setText(tempFilename);
}// Of actionPerformed
/**
**********************************
* Implement FocusListener.
*
* @param paraEvent
* The event is unimportant.
**********************************
*/
public void focusGained(FocusEvent paraEvent) {
}// Of focusGained
/**
**********************************
* Implement FocusListener.
*
* @param paraEvent
* The event is unimportant.
**********************************
*/
public void focusLost(FocusEvent paraEvent) {
// System.out.println("Focus lost exists.");
String tempString = getText();
if ((tempString.equals("unspecified"))
|| (tempString.equals("")))
return;
File tempFile = new File(tempString);
if (!tempFile.exists()) {
ErrorDialog.errorDialog.setMessageAndShow("File \"" + tempString
+ "\" not exists. Please check.");
requestFocus();
setText("");
}
}// Of focusLost
}// Of class FilenameField
1. GridLayout and BorderLayout are used to arrange the controls.
2. Clicking OK executes actionPerformed.
package xjx_GUI;
import java.awt.*;
import java.awt.event.*;
import java.util.Date;
import xjx.FullAnn;
public class AnnMain implements ActionListener {
/**
* Select the arff file.
*/
private FilenameField arffFilenameField;
/**
* The setting of alpha.
*/
private DoubleField alphaField;
/**
* The setting of beta.
*/
private DoubleField betaField;
/**
* The setting of gamma.
*/
private DoubleField gammaField;
/**
* Layer nodes, such as "4, 8, 8, 3".
*/
private TextField layerNodesField;
/**
* Activators, such as "ssa".
*/
private TextField activatorField;
/**
* The number of training rounds.
*/
private IntegeField roundsField;
/**
* The learning rate.
*/
private DoubleField learningRateField;
/**
* The mobp.
*/
private DoubleField mobpField;
/**
* The message area.
*/
private TextArea messageTextArea;
/**
***************************
* The only constructor.
***************************
*/
public AnnMain() {
// A simple frame to contain dialogs.
Frame mainFrame = new Frame();
mainFrame.setTitle("ANN");
// The top part: select arff file.
arffFilenameField = new FilenameField(30);
arffFilenameField.setText("d:/data/iris.arff");
Button browseButton = new Button(" Browse ");
browseButton.addActionListener(arffFilenameField);
Panel sourceFilePanel = new Panel();
sourceFilePanel.add(new Label("The .arff file:"));
sourceFilePanel.add(arffFilenameField);
sourceFilePanel.add(browseButton);
// Setting panel.
Panel settingPanel = new Panel();
settingPanel.setLayout(new GridLayout(3, 6));
settingPanel.add(new Label("alpha"));
alphaField = new DoubleField("0.01");
settingPanel.add(alphaField);
settingPanel.add(new Label("beta"));
betaField = new DoubleField("0.02");
settingPanel.add(betaField);
settingPanel.add(new Label("gamma"));
gammaField = new DoubleField("0.03");
settingPanel.add(gammaField);
settingPanel.add(new Label("layer nodes"));
layerNodesField = new TextField("4, 8, 8, 3");
settingPanel.add(layerNodesField);
settingPanel.add(new Label("activators"));
activatorField = new TextField("sss");
settingPanel.add(activatorField);
settingPanel.add(new Label("training rounds"));
roundsField = new IntegeField("5000");
settingPanel.add(roundsField);
settingPanel.add(new Label("learning rate"));
learningRateField = new DoubleField("0.01");
settingPanel.add(learningRateField);
settingPanel.add(new Label("mobp"));
mobpField = new DoubleField("0.5");
settingPanel.add(mobpField);
Panel topPanel = new Panel();
topPanel.setLayout(new BorderLayout());
topPanel.add(BorderLayout.NORTH, sourceFilePanel);
topPanel.add(BorderLayout.CENTER, settingPanel);
messageTextArea = new TextArea(50, 40);
// The bottom part: ok and exit
Button okButton = new Button(" OK ");
okButton.addActionListener(this);
// DialogCloser dialogCloser = new DialogCloser(this);
Button exitButton = new Button(" Exit ");
// cancelButton.addActionListener(dialogCloser);
exitButton.addActionListener(ApplicationShutdown.applicationShutdown);
Button helpButton = new Button(" Help ");
helpButton.setSize(20, 10);
helpButton.addActionListener(new HelpDialog("ANN", "src/machinelearning/gui/help.txt"));
Panel okPanel = new Panel();
okPanel.add(okButton);
okPanel.add(exitButton);
okPanel.add(helpButton);
mainFrame.setLayout(new BorderLayout());
mainFrame.add(BorderLayout.NORTH, topPanel);
mainFrame.add(BorderLayout.CENTER, messageTextArea);
mainFrame.add(BorderLayout.SOUTH, okPanel);
mainFrame.setSize(600, 500);
mainFrame.setLocation(100, 100);
mainFrame.addWindowListener(ApplicationShutdown.applicationShutdown);
mainFrame.setBackground(GUICommon.MY_COLOR);
mainFrame.setVisible(true);
}// Of the constructor
/**
***************************
* Read the settings, train the network, and report the training time and accuracy.
***************************
*/
public void actionPerformed(ActionEvent ae) {
String tempFilename = arffFilenameField.getText();
// Read the layer nodes.
String tempString = layerNodesField.getText().trim();
int[] tempLayerNodes = null;
try {
tempLayerNodes = stringToIntArray(tempString);
} catch (Exception ee) {
ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
return;
} // Of try
double tempLearningRate = learningRateField.getValue();
double tempMobp = mobpField.getValue();
String tempActivators = activatorField.getText().trim();
FullAnn tempNetwork = new FullAnn(tempFilename, tempLayerNodes, tempLearningRate, tempMobp,
tempActivators);
int tempRounds = roundsField.getValue();
long tempStartTime = new Date().getTime();
for (int i = 0; i < tempRounds; i++) {
tempNetwork.train();
} // Of for i
long tempEndTime = new Date().getTime();
messageTextArea.append("\r\nSummary:\r\n");
messageTextArea.append("Training time: " + (tempEndTime - tempStartTime) + "ms.\r\n");
double tempAccuracy = tempNetwork.test();
messageTextArea.append("Accuracy: " + tempAccuracy + "\r\n");
messageTextArea.append("End.");
}// Of actionPerformed
/**
**********************************
* Convert a string with commas into an int array.
*
* @param paraString
* The source string
* @return An int array.
* @throws Exception
* Exception for illegal data.
**********************************
*/
public static int[] stringToIntArray(String paraString) throws Exception {
int tempCounter = 1;
for (int i = 0; i < paraString.length(); i++) {
if (paraString.charAt(i) == ',') {
tempCounter++;
} // Of if
} // Of for i
int[] resultArray = new int[tempCounter];
String tempRemainingString = new String(paraString) + ",";
String tempString;
for (int i = 0; i < tempCounter; i++) {
tempString = tempRemainingString.substring(0, tempRemainingString.indexOf(",")).trim();
if (tempString.equals("")) {
throw new Exception("Blank is unsupported");
} // Of if
resultArray[i] = Integer.parseInt(tempString);
tempRemainingString = tempRemainingString
.substring(tempRemainingString.indexOf(",") + 1);
} // Of for i
return resultArray;
}// Of stringToIntArray
/**
***************************
* The entrance method.
*
* @param args
* The parameters.
***************************
*/
public static void main(String args[]) {
new AnnMain();
}// Of main
}// Of class AnnMain
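AnnMain.stringToIntArray counts the commas and then cuts the string apart piece by piece. The following is a shorter sketch with the same accept/reject behavior, built on String.split. LayerParser is a hypothetical name. The limit of -1 matters: without it, split drops trailing empty strings, so input such as "4, 8," would silently parse as [4, 8] instead of being rejected.

```java
import java.util.Arrays;

// Hypothetical alternative to AnnMain.stringToIntArray.
final class LayerParser {
    private LayerParser() {
    }

    /**
     * Converts a comma-separated string such as "4, 8, 8, 3" into an int array.
     *
     * @throws IllegalArgumentException if an entry is blank
     *         (NumberFormatException, a subclass, if it is not an integer)
     */
    static int[] stringToIntArray(String paraString) {
        // -1 keeps trailing empty entries, so "4, 8," is rejected like in the original loop.
        String[] parts = paraString.split(",", -1);
        int[] result = new int[parts.length];
        for (int i = 0; i < parts.length; i++) {
            String part = parts[i].trim();
            if (part.isEmpty()) {
                throw new IllegalArgumentException("Blank is unsupported");
            }
            result[i] = Integer.parseInt(part);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(stringToIntArray("4, 8, 8, 3"))); // [4, 8, 8, 3]
    }
}
```

Both exceptions are unchecked subclasses of Exception, so the existing catch (Exception ee) in actionPerformed would still report them through ErrorDialog.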
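1. From the perspective of the listener mechanism, interfaces, and so on, analyze which code each operation on the GUI triggers;
2. Summarize the basic artificial neural network.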
1. The Java event listener mechanism
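In GUI programs such as the one above, menu bars, menu items, buttons, and so on are all objects, and clicking one of them should carry out some task. The same goes for mouse operations such as clicking, double-clicking, or the mouse entering or leaving a component. Java handles this with its event listener mechanism: when an event occurs (a button is clicked, the mouse moves, a window is closed, etc.), it is detected and handled by listener objects that have been registered on the source component.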
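User action | Source object | Event type fired |
---|---|---|
Click a button | JButton (AWT: Button) | ActionEvent |
Press Enter in a text field | JTextField (AWT: TextField) | ActionEvent |
Open, close, or minimize a window | Window | WindowEvent |
Click, double-click, or move the mouse | Component | MouseEvent |
Click a radio button | JRadioButton | ItemEvent, ActionEvent |
Click a check box | JCheckBox | ItemEvent, ActionEvent |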
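In Java, every event is represented by an object of a subclass of java.util.EventObject. For example:
MouseEvent: a mouse action, such as moving the cursor, clicking, or double-clicking.
KeyEvent: a key press on the keyboard.
ActionEvent: a user-interface action taken by the user, such as clicking a button on the screen.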
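The registration and dispatch steps described above map onto this program as follows:

- okButton.addActionListener(this) registers AnnMain, so clicking OK calls AnnMain.actionPerformed, which reads the settings, trains, and tests.
- The Browse button registers arffFilenameField, so clicking it calls FilenameField.actionPerformed, which opens the FileDialog.
- The Help button calls HelpDialog.actionPerformed, which makes the dialog visible.
- The Exit button and closing the window both go to ApplicationShutdown.applicationShutdown.
- DoubleField, IntegeField, and FilenameField register themselves as FocusListeners, so leaving one of them runs its focusLost check.

The following is a minimal, hypothetical imitation of that pattern. ClickEvent, ClickListener, and FakeButton are illustrative names, not AWT classes. The sketch is built only on java.util.EventObject and java.util.EventListener, so it runs without a display.

```java
import java.util.ArrayList;
import java.util.EventListener;
import java.util.EventObject;
import java.util.List;

// Hypothetical event type, playing the role of ActionEvent.
class ClickEvent extends EventObject {
    ClickEvent(Object source) {
        super(source);
    }
}

// Hypothetical listener interface, playing the role of ActionListener.
interface ClickListener extends EventListener {
    void clicked(ClickEvent event);
}

// Hypothetical event source, playing the role of Button.
class FakeButton {
    private final List<ClickListener> listeners = new ArrayList<>();

    /** Registration, analogous to Button.addActionListener. */
    void addClickListener(ClickListener listener) {
        listeners.add(listener);
    }

    /** Simulates a user click: one event object is passed to every registered listener. */
    void click() {
        ClickEvent event = new ClickEvent(this);
        for (ClickListener listener : listeners) {
            listener.clicked(event);
        }
    }
}

class ListenerDemo {
    public static void main(String[] args) {
        FakeButton okButton = new FakeButton();
        okButton.addClickListener(e -> System.out.println("train and test the network"));
        okButton.addClickListener(e -> System.out.println("source is okButton: " + (e.getSource() == okButton)));
        okButton.click();
    }
}
```

The source never needs to know what its listeners do. It only calls the interface method, which is why AnnMain, FilenameField, and HelpDialog can all be plugged into buttons the same way.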
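2. Summary of the basic artificial neural network
BP neural network:
An iterative algorithm: initialize the parameters randomly, compute the current network output, use the difference between that output and the label to adjust the parameters of the preceding layers, and repeat until convergence;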
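The training loop described above can be sketched as a tiny network. The following is a hypothetical 2-2-1 example, not the FullAnn class used by AnnMain, and it makes several assumptions:

- sigmoid activations with bias terms
- squared error
- per-sample (online) updates
- no momentum (mobp) term
- logical OR as toy data

Each step computes the output, turns the output error into a delta using f'(net) = y(1-y), and propagates that delta back to the hidden layer before the weights are updated.

```java
import java.util.Random;

// Hypothetical 2-2-1 BP network with sigmoid activations, trained on logical OR.
class TinyBp {
    static final double[][] INPUTS = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
    static final double[] LABELS = {0, 1, 1, 1};

    double[][] w1 = new double[2][2]; // w1[i][j]: input i -> hidden j
    double[] b1 = new double[2];      // hidden biases
    double[] w2 = new double[2];      // w2[j]: hidden j -> output
    double b2;                        // output bias

    TinyBp(long seed) {
        Random random = new Random(seed);
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                w1[i][j] = random.nextDouble() - 0.5;
            }
            w2[i] = random.nextDouble() - 0.5;
        }
    }

    static double sigmoid(double net) {
        return 1.0 / (1.0 + Math.exp(-net));
    }

    /** Forward pass: fills hidden[] and returns the output y. */
    double forward(double[] x, double[] hidden) {
        double net = b2;
        for (int j = 0; j < 2; j++) {
            hidden[j] = sigmoid(x[0] * w1[0][j] + x[1] * w1[1][j] + b1[j]);
            net += hidden[j] * w2[j];
        }
        return sigmoid(net);
    }

    /** One pass over the data; returns the summed squared error seen during the pass. */
    double trainEpoch(double learningRate) {
        double loss = 0;
        double[] hidden = new double[2];
        for (int n = 0; n < INPUTS.length; n++) {
            double y = forward(INPUTS[n], hidden);
            double error = LABELS[n] - y;
            loss += 0.5 * error * error;
            // Output delta: error times f'(net) = y(1 - y).
            double deltaOut = error * y * (1 - y);
            for (int j = 0; j < 2; j++) {
                // Hidden delta uses w2[j] before it is updated.
                double deltaHidden = hidden[j] * (1 - hidden[j]) * w2[j] * deltaOut;
                w2[j] += learningRate * deltaOut * hidden[j];
                for (int i = 0; i < 2; i++) {
                    w1[i][j] += learningRate * deltaHidden * INPUTS[n][i];
                }
                b1[j] += learningRate * deltaHidden;
            }
            b2 += learningRate * deltaOut;
        }
        return loss;
    }

    public static void main(String[] args) {
        TinyBp network = new TinyBp(42);
        double first = network.trainEpoch(0.5);
        double last = first;
        for (int epoch = 1; epoch < 5000; epoch++) {
            last = network.trainEpoch(0.5);
        }
        System.out.printf("loss: %.4f -> %.4f%n", first, last);
    }
}
```

The fixed seed makes the run reproducible. The error after the last epoch should be lower than after the first, although the exact values depend on the seed and the learning rate.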
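Drawbacks:
The gradient becomes increasingly diffuse: the farther a layer is from the top (output) layer, the smaller its error-correction signal becomes;
It can get stuck in a local optimum, especially when training starts far from the optimal region (a consequence of random initialization);
It can generally be trained only on labeled data, yet most data is unlabeled;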
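The first drawback can be made concrete with the sigmoid derivative f'(net) = y(1-y) given earlier. This uses standard BP notation, sketched here rather than taken from the code: $\delta_j^{(l)}$ is the error term of node $j$ in layer $l$, and $w_{jk}$ is the weight from node $j$ to node $k$ in layer $l+1$. The hidden error term is then

$$
\delta_j^{(l)} = f'\bigl(net_j^{(l)}\bigr)\sum_k w_{jk}\,\delta_k^{(l+1)} = y_j^{(l)}\bigl(1-y_j^{(l)}\bigr)\sum_k w_{jk}\,\delta_k^{(l+1)}
$$

Since $y(1-y)\le\frac14$ for every $y\in(0,1)$, with equality only at $y=\frac12$,

$$
\bigl|\delta_j^{(l)}\bigr| \le \frac14\sum_k \bigl|w_{jk}\bigr|\,\bigl|\delta_k^{(l+1)}\bigr|
$$

If $\sum_k |w_{jk}|\le 1$ for every node, then $\max_j|\delta_j^{(l)}|\le\frac14\max_k|\delta_k^{(l+1)}|$. Across five layers, the bound therefore shrinks to $(\frac14)^5=\frac{1}{1024}\approx 0.001$ of the output error.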