Grinding 300 Lines of Java a Day (Days 71-80, BP Neural Networks)

Contents

Overview
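Days 01-10, basic syntax
Days 11-20, linear data structures
Days 21-30, trees and binary trees
Days 31-40, graphs
Days 41-50, searching and sorting
Days 51-60, kNN and NB
Days 61-70, decision trees and ensemble learning
Days 71-80, BP neural networks
Days 81-90, CNN convolutional neural networks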

Day 71: Base class for BP neural networks (data reading and basic structure)

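  1. This is the era in which neural networks dominate.
    1.1 The code I was given had only 70 lines; somehow, after I revised it, it grew to 300+.
    1.2 Today's program was deliberately split into pieces for the sake of reusability.
  2. Basic mechanisms of neural networks
    2.1 Forward prediction
    2.1.1 Weighted sum
    A neural network does only one thing: attribute (feature) extraction.
    2.1.2 Activation function
    It introduces nonlinearity; otherwise a multi-layer network would be equivalent to a single layer.
    2.2 Backward: back-propagation adjusts the weights
    Gradient descent: an accountability mechanism, where each weight is adjusted according to its share of the error.
    The weights are initialized to random values.
    If the class is 0, the standard answer at the output is [1, 0, 0]; if the actual output is [0.3, 0.5, 0.6], the error (target minus output) is [0.7, -0.5, -0.6].
    2.3 Fixed network structure
    The number of layers and the number of nodes per layer are set by hand: hence the nickname "network hyperparameter tuner".
  3. Data
    3.1 An instance
    A one-dimensional array (an m-dimensional vector)
    3.2 The weights of a single layer
    A two-dimensional array
    3.3 All the weights
    A three-dimensional array
    A jagged array (rows of different lengths)
  4. Effects
    4.1 The deeper the network, the more complex the model; on simple data (few attributes) it may overfit.
    4.2 In terms of expressive power, a sufficiently wide network is equivalent to a deep one.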
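To make items 2.1 and 2.2 concrete, here is a minimal standalone sketch (not part of the original code; the class and method names are illustrative). It computes one node's output as a weighted sum plus an offset followed by sigmoid, and the error vector of the [1, 0, 0] versus [0.3, 0.5, 0.6] example above.

```java
import java.util.Arrays;

/** Illustrative sketch: one node's forward pass and the output error vector. */
class NeuronSketch {

    /** Weighted sum of the inputs plus the offset (bias), then sigmoid. */
    static double nodeOutput(double[] inputs, double[] weights, double offset) {
        double z = offset;
        for (int i = 0; i < inputs.length; i++) {
            z += weights[i] * inputs[i];
        }
        return 1 / (1 + Math.exp(-z)); // sigmoid adds the nonlinearity
    }

    /** Error = standard answer (one-hot target) minus actual output. */
    static double[] outputError(double[] target, double[] actual) {
        double[] error = new double[target.length];
        for (int i = 0; i < target.length; i++) {
            error[i] = target[i] - actual[i];
        }
        return error;
    }

    public static void main(String[] args) {
        // z = 0.5 * 1 + (-0.25) * 2 + 0 = 0, hence sigmoid(0) = 0.5.
        System.out.println(nodeOutput(new double[] {1, 2}, new double[] {0.5, -0.25}, 0)); // prints 0.5
        // Class 0 => target [1, 0, 0]; the actual output is [0.3, 0.5, 0.6].
        double[] error = outputError(new double[] {1, 0, 0}, new double[] {0.3, 0.5, 0.6});
        System.out.println(Arrays.toString(error));
    }
}
```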
package machinelearning.ann;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.Instances;

/**
 * General ANN. Two methods are abstract: forward and backPropagation.
 * 
 * @author Fan Min [email protected].
 */
public abstract class GeneralAnn {

	/**
	 * The whole dataset.
	 */
	Instances dataset;

	/**
	 * Number of layers. It is counted according to nodes instead of edges.
	 */
	int numLayers;

	/**
	 * The number of nodes for each layer, e.g., [3, 4, 6, 2] means that there
	 * are 3 input nodes (conditional attributes), 2 hidden layers with 4 and 6
	 * nodes, respectively, and 2 class values (binary classification).
	 */
	int[] layerNumNodes;

	/**
	 * Momentum coefficient.
	 */
	public double mobp;

	/**
	 * Learning rate.
	 */
	public double learningRate;

	/**
	 * For random number generation.
	 */
	Random random = new Random();

	/**
	 ********************
	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The arff filename.
	 * @param paraLayerNumNodes
	 *            The number of nodes for each layer (may be different).
	 * @param paraLearningRate
	 *            Learning rate.
	 * @param paraMobp
	 *            Momentum coefficient.
	 ********************
	 */
	public GeneralAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
			double paraMobp) {
		// Step 1. Read data.
		try {
			FileReader tempReader = new FileReader(paraFilename);
			dataset = new Instances(tempReader);
			// The last attribute is the decision class.
			dataset.setClassIndex(dataset.numAttributes() - 1);
			tempReader.close();
		} catch (Exception ee) {
			System.out.println("Error occurred while trying to read \'" + paraFilename
					+ "\' in GeneralAnn constructor.\r\n" + ee);
			System.exit(0);
		} // Of try

		// Step 2. Accept parameters.
		layerNumNodes = paraLayerNumNodes;
		numLayers = layerNumNodes.length;
		// Adjust if necessary.
		layerNumNodes[0] = dataset.numAttributes() - 1;
		layerNumNodes[numLayers - 1] = dataset.numClasses();
		learningRate = paraLearningRate;
		mobp = paraMobp;	
	}//Of the first constructor	
	
	/**
	 ********************
	 * Forward prediction.
	 * 
	 * @param paraInput
	 *            The input data of one instance.
	 * @return The data at the output end.
	 ********************
	 */
	public abstract double[] forward(double[] paraInput);

	/**
	 ********************
	 * Back propagation.
	 * 
	 * @param paraTarget
	 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 *            
	 ********************
	 */
	public abstract void backPropagation(double[] paraTarget);

	/**
	 ********************
	 * Train using the dataset.
	 ********************
	 */
	public void train() {
		double[] tempInput = new double[dataset.numAttributes() - 1];
		double[] tempTarget = new double[dataset.numClasses()];
		for (int i = 0; i < dataset.numInstances(); i++) {
			// Fill the data.
			for (int j = 0; j < tempInput.length; j++) {
				tempInput[j] = dataset.instance(i).value(j);
			} // Of for j

			// Fill the class label.
			Arrays.fill(tempTarget, 0);
			tempTarget[(int) dataset.instance(i).classValue()] = 1;

			// Train with this instance.
			forward(tempInput);
			backPropagation(tempTarget);
		} // Of for i
	}// Of train

	/**
	 ********************
	 * Get the index corresponding to the max value of the array.
	 * 
	 * @return the index.
	 ********************
	 */
	public static int argmax(double[] paraArray) {
		int resultIndex = -1;
		double tempMax = Double.NEGATIVE_INFINITY;
		for (int i = 0; i < paraArray.length; i++) {
			if (tempMax < paraArray[i]) {
				tempMax = paraArray[i];
				resultIndex = i;
			} // Of if
		} // Of for i

		return resultIndex;
	}// Of argmax

	/**
	 ********************
	 * Test using the dataset.
	 * 
	 * @return The accuracy.
	 ********************
	 */
	public double test() {
		double[] tempInput = new double[dataset.numAttributes() - 1];

		double tempNumCorrect = 0;
		double[] tempPrediction;
		int tempPredictedClass = -1;

		for (int i = 0; i < dataset.numInstances(); i++) {
			// Fill the data.
			for (int j = 0; j < tempInput.length; j++) {
				tempInput[j] = dataset.instance(i).value(j);
			} // Of for j

			// Predict with this instance (forward only).
			tempPrediction = forward(tempInput);
			//System.out.println("prediction: " + Arrays.toString(tempPrediction));
			tempPredictedClass = argmax(tempPrediction);
			if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
				tempNumCorrect++;
			} // Of if
		} // Of for i

		System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());

		return tempNumCorrect / dataset.numInstances();
	}// Of test
}//Of class GeneralAnn
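train() turns the class index of each instance into a one-hot target, and test() maps the output vector back to a class with argmax. The standalone sketch below isolates those two conversions without Weka; the class and method names are illustrative. Its argmax starts from Double.NEGATIVE_INFINITY, so it still works if every output is very negative.

```java
import java.util.Arrays;

/** Illustrative sketch of the label encoding used by train() and test(). */
class LabelEncodingSketch {

    /** Class index -> one-hot target, e.g. class 1 of 3 -> [0, 1, 0]. */
    static double[] oneHot(int classIndex, int numClasses) {
        double[] target = new double[numClasses];
        target[classIndex] = 1;
        return target;
    }

    /** Output vector -> predicted class index (the first maximum wins). */
    static int argmax(double[] values) {
        int best = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            if (values[i] > bestValue) {
                bestValue = values[i];
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(oneHot(1, 3))); // prints [0.0, 1.0, 0.0]
        System.out.println(argmax(new double[] {0.3, 0.5, 0.6})); // prints 2
    }
}
```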

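Day 72: BP neural network with a fixed activation function (1. Understanding the network structure)

The network structure and the data are determined by a few arrays. They are best understood while running the program.

  1. layerNumNodes describes the basic network structure. For example, [3, 4, 6, 2] means:
    a) There are 3 input ports, i.e., the data have 3 conditional attributes. If this does not match the actual data, the code corrects it automatically (see the "Adjust if necessary" step in the GeneralAnn constructor).
    b) There are 2 output ports, i.e., the data have 2 decision classes. If this does not match the actual data, the code corrects it automatically (same place). For a classification problem, the predicted class is the one whose output port has the largest value.
    c) There are two hidden layers, with 4 and 6 nodes, respectively.
  2. layerNodeValues holds the value of each network node. In the example above, the network has 4 layers of nodes, i.e., layerNodeValues.length is 4. There are $3 + 4 + 6 + 2 = 15$ nodes in total, i.e., layerNodeValues[0].length = 3, layerNodeValues[1].length = 4, layerNodeValues[2].length = 6, and layerNodeValues[3].length = 2. Java supports such jagged matrices (different rows have different numbers of columns) because a two-dimensional array is treated as a one-dimensional array of one-dimensional arrays.
  3. layerNodeErrors holds the error on each network node. It has the same size as layerNodeValues.
  4. edgeWeights holds the weight of each edge. Since the edges between two layers form a many-to-many relation (a two-dimensional array), the edges of multiple layers form a three-dimensional array. For example, edge layer 0 in the example above has $(3 + 1) \times 4 = 16$ edges, where the $+1$ accounts for the offset. There are $4 - 1 = 3$ edge layers in total, i.e., one fewer than the number of node layers. This is a very error-prone spot when writing the program.
  5. edgeWeightsDelta has the same size as edgeWeights and assists in adjusting it: it stores the previous change of each weight, which the momentum term (mobp) reuses.
    The core code follows.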
package machinelearning.ann;

/**
 * Back-propagation neural networks. The code comes from
 * https://mp.weixin.qq.com/s?__biz=MjM5MjAwODM4MA==&mid=402665740&idx=1&sn=18d84d72934e59ca8bcd828782172667
 * 
 * @author 彭渊 revised by [email protected]
 */

public class SimpleAnn extends GeneralAnn{

	/**
	 * The value of each node that changes during the forward process. The first
	 * dimension stands for the layer, and the second stands for the node.
	 */
	public double[][] layerNodeValues;

	/**
	 * The error on each node that changes during the back-propagation process.
	 * The first dimension stands for the layer, and the second stands for the
	 * node.
	 */
	public double[][] layerNodeErrors;

	/**
	 * The weights of edges. The first dimension stands for the layer, the
	 * second stands for the node index of the layer, and the third dimension
	 * stands for the node index of the next layer.
	 */
	public double[][][] edgeWeights;

	/**
	 * The change of edge weights. It has the same size as edgeWeights.
	 */
	public double[][][] edgeWeightsDelta;

	/**
	 ********************
	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The arff filename.
	 * @param paraLayerNumNodes
	 *            The number of nodes for each layer (may be different).
	 * @param paraLearningRate
	 *            Learning rate.
	 * @param paraMobp
	 *            Momentum coefficient.
	 ********************
	 */
	public SimpleAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
			double paraMobp) {
		super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);

		// Step 1. Across layer initialization.
		layerNodeValues = new double[numLayers][];
		layerNodeErrors = new double[numLayers][];
		edgeWeights = new double[numLayers - 1][][];
		edgeWeightsDelta = new double[numLayers - 1][][];

		// Step 2. Inner layer initialization.
		for (int l = 0; l < numLayers; l++) {
			layerNodeValues[l] = new double[layerNumNodes[l]];
			layerNodeErrors[l] = new double[layerNumNodes[l]];

			// One less layer because each edge crosses two layers.
			if (l + 1 == numLayers) {
				break;
			} // of if

			// In layerNumNodes[l] + 1, the last one is reserved for the offset.
			edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
			edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
			for (int j = 0; j < layerNumNodes[l] + 1; j++) {
				for (int i = 0; i < layerNumNodes[l + 1]; i++) {
					// Initialize weights.
					edgeWeights[l][j][i] = random.nextDouble();
				} // Of for i
			} // Of for j
		} // Of for l
	}// Of the constructor

	/**
	 ********************
	 * Forward prediction.
	 * 
	 * @param paraInput
	 *            The input data of one instance.
	 * @return The data at the output end.
	 ********************
	 */
	public double[] forward(double[] paraInput) {
		// Initialize the input layer.
		for (int i = 0; i < layerNodeValues[0].length; i++) {
			layerNodeValues[0][i] = paraInput[i];
		} // Of for i

		// Calculate the node values of each layer.
		double z;
		for (int l = 1; l < numLayers; l++) {
			for (int j = 0; j < layerNodeValues[l].length; j++) {
				// Initialize according to the offset, which is always +1
				z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
				// Weighted sum on all edges for this node.
				for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
					z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
				} // Of for i

				// Sigmoid activation.
				// This line should be changed for other activation functions.
				layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
			} // Of for j
		} // Of for l

		return layerNodeValues[numLayers - 1];
	}// Of forward

	/**
	 ********************
	 * Back propagation and change the edge weights.
	 * 
	 * @param paraTarget
	 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 ********************
	 */
	public void backPropagation(double[] paraTarget) {
		// Step 1. Initialize the output layer error.
		int l = numLayers - 1;
		for (int j = 0; j < layerNodeErrors[l].length; j++) {
			layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j])
					* (paraTarget[j] - layerNodeValues[l][j]);
		} // Of for j

		// Step 2. Back-propagation even for l == 0
		while (l > 0) {
			l--;
			// Layer l, for each node.
			for (int j = 0; j < layerNumNodes[l]; j++) {
				double z = 0.0;
				// For each node of the next layer.
				for (int i = 0; i < layerNumNodes[l + 1]; i++) {
					if (l > 0) {
						z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
					} // Of if

					// Weight adjusting.
					edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i]
							+ learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
					edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];
					if (j == layerNumNodes[l] - 1) {
						// Weight adjusting for the offset part.
						edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
								+ learningRate * layerNodeErrors[l + 1][i];
						edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
					} // Of if
				} // Of for i

				// Record the error according to the differential of Sigmoid.
				// This line should be changed for other activation functions.
				layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
			} // Of for j
		} // Of while
	}// Of backPropagation

	/**
	 ********************
	 * Test the algorithm.
	 ********************
	 */
	public static void main(String[] args) {
		int[] tempLayerNodes = { 4, 8, 8, 3 };
		SimpleAnn tempNetwork = new SimpleAnn("D:/data/iris.arff", tempLayerNodes, 0.01,
				0.6);

		for (int round = 0; round < 5000; round++) {
			tempNetwork.train();
		} // Of for round

		double tempAccuracy = tempNetwork.test();
		System.out.println("The accuracy is: " + tempAccuracy);
	}// Of main
}// Of class SimpleAnn
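The array shapes described in items 2 to 5 of Day 72 can be checked in isolation. The standalone sketch below (class and method names are illustrative, not part of the original project) allocates the same jagged arrays as the SimpleAnn constructor for [3, 4, 6, 2] and counts the edges, offsets included: $(3+1)\times 4 + (4+1)\times 6 + (6+1)\times 2 = 16 + 30 + 14 = 60$.

```java
/** Illustrative sketch: the jagged arrays SimpleAnn allocates for {3, 4, 6, 2}. */
class JaggedShapesSketch {

    /** One row per node layer; row l has layerNumNodes[l] entries. */
    static double[][] nodeValues(int[] layerNumNodes) {
        double[][] values = new double[layerNumNodes.length][];
        for (int l = 0; l < layerNumNodes.length; l++) {
            values[l] = new double[layerNumNodes[l]];
        }
        return values;
    }

    /** One matrix per edge layer (one fewer than node layers); the extra row is the offset. */
    static double[][][] edgeWeights(int[] layerNumNodes) {
        double[][][] weights = new double[layerNumNodes.length - 1][][];
        for (int l = 0; l < weights.length; l++) {
            weights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
        }
        return weights;
    }

    /** Total number of edges, offsets included. */
    static int countEdges(double[][][] weights) {
        int count = 0;
        for (double[][] layer : weights) {
            for (double[] row : layer) {
                count += row.length;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[] layerNumNodes = {3, 4, 6, 2};
        double[][] values = nodeValues(layerNumNodes);
        double[][][] weights = edgeWeights(layerNumNodes);
        // prints "4 node layers, 3 edge layers"
        System.out.println(values.length + " node layers, " + weights.length + " edge layers");
        // Edge layer 0 is (3 + 1) x 4; prints "4 x 4"
        System.out.println(weights[0].length + " x " + weights[0][0].length);
        // prints "60 edges in total"
        System.out.println(countEdges(weights) + " edges in total");
    }
}
```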

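Day 73: BP neural network with a fixed activation function (2. Understanding training and testing)

  1. Forward is the process of predicting one instance with the current network.
  2. BackPropagation is the process of adjusting the network weights according to the error.
  3. Training needs both the forward and the backward pass; testing needs only the forward pass.
  4. Only the sigmoid activation is implemented here, and the derivative used in back-propagation must match the activation used in the forward pass. To switch activation functions, change them together: the activation in forward, and the sigmoid derivative in backPropagation, which appears both in Step 1 (output layer) and in Step 2 (hidden layers).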
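The training/testing split above can be seen on the smallest possible network. The standalone sketch below trains a single sigmoid node with the same error formula as Step 1 of SimpleAnn.backPropagation and the same momentum update, then tests with forward only. Assumptions chosen for a reproducible demo (not from the original): the data are logical OR, the weights start at zero instead of random values, the learning rate is 0.5 and the momentum 0.5.

```java
/** Illustrative sketch: one sigmoid node trained with SimpleAnn-style updates. */
class SingleNodeSketch {

    /** weights[0] and weights[1] are edge weights; weights[2] is the offset, as in SimpleAnn. */
    private final double[] weights = new double[3];

    /** Previous changes, kept for the momentum term (like edgeWeightsDelta). */
    private final double[] deltas = new double[3];

    private final double learningRate;
    private final double mobp;
    private double[] lastInput;
    private double lastOutput;

    SingleNodeSketch(double learningRate, double mobp) {
        this.learningRate = learningRate;
        this.mobp = mobp;
    }

    /** Forward: weighted sum plus offset, then sigmoid. */
    double forward(double[] input) {
        lastInput = input;
        double z = weights[2];
        for (int i = 0; i < 2; i++) {
            z += weights[i] * input[i];
        }
        lastOutput = 1 / (1 + Math.exp(-z));
        return lastOutput;
    }

    /** Backward: sigmoid derivative times (target - output), then the momentum update. */
    void backPropagation(double target) {
        double error = lastOutput * (1 - lastOutput) * (target - lastOutput);
        for (int i = 0; i < 3; i++) {
            double x = (i < 2) ? lastInput[i] : 1; // the offset always sees input +1
            deltas[i] = mobp * deltas[i] + learningRate * error * x;
            weights[i] += deltas[i];
        }
    }

    static final double[][] INPUTS = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
    static final double[] OR_TARGETS = { 0, 1, 1, 1 };

    /** Trains on logical OR, then tests with forward only; returns the number correct. */
    static int trainAndTest(int rounds) {
        SingleNodeSketch node = new SingleNodeSketch(0.5, 0.5);
        for (int round = 0; round < rounds; round++) {
            for (int k = 0; k < INPUTS.length; k++) {
                node.forward(INPUTS[k]); // training: forward ...
                node.backPropagation(OR_TARGETS[k]); // ... then backward
            }
        }

        int correct = 0;
        for (int k = 0; k < INPUTS.length; k++) {
            int predicted = node.forward(INPUTS[k]) >= 0.5 ? 1 : 0; // testing: forward only
            if (predicted == (int) OR_TARGETS[k]) {
                correct++;
            }
        }
        return correct;
    }

    public static void main(String[] args) {
        System.out.println("Correct: " + trainAndTest(5000) + " out of 4");
    }
}
```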

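Day 74: General BP neural network (1. Managing activation functions in one place)

Activation functions are at the heart of a neural network. Although today's code runs to 300 lines, it is simple.

  1. Activation and derivation come as a pair: the former is used in forward, the latter in back-propagation.
  2. There are many activation functions, and their design follows certain criteria, e.g., being piecewise differentiable.
  3. Look up and implement the activation functions left unimplemented (GELU, hard logistic and Swish are commented out).
  4. Test further.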
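Because each derivative must match its activation, one cheap form of further testing is to compare the analytic derivative with a central finite difference $(f(x+h) - f(x-h)) / 2h$. The standalone sketch below mirrors the derive(x, f(x)) signature of the Activator class; the step $h = 10^{-5}$ and the sample points are arbitrary choices, and the points avoid the kink at 0 that piecewise functions such as ReLU have.

```java
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;

/** Illustrative sketch: checks an activation/derivative pair by central differences. */
class DerivativeCheckSketch {

    /** Largest gap between derive(x, f(x)) and (f(x + h) - f(x - h)) / (2h) over the points. */
    static double maxGap(DoubleUnaryOperator activate, DoubleBinaryOperator derive, double[] xs) {
        double h = 1e-5;
        double worst = 0;
        for (double x : xs) {
            double numeric = (activate.applyAsDouble(x + h) - activate.applyAsDouble(x - h)) / (2 * h);
            double analytic = derive.applyAsDouble(x, activate.applyAsDouble(x));
            worst = Math.max(worst, Math.abs(numeric - analytic));
        }
        return worst;
    }

    public static void main(String[] args) {
        double[] xs = {-3, -1, -0.2, 0.6, 2};
        // Sigmoid and tanh: the derivative is written in terms of f(x).
        double sigmoidGap = maxGap(x -> 1 / (1 + Math.exp(-x)), (x, fx) -> fx * (1 - fx), xs);
        double tanhGap = maxGap(Math::tanh, (x, fx) -> 1 - fx * fx, xs);
        // Arc tan: the derivative is written in terms of x.
        double atanGap = maxGap(Math::atan, (x, fx) -> 1 / (x * x + 1), xs);
        System.out.println(sigmoidGap + " " + tanhGap + " " + atanGap);
    }
}
```

A gap far above roughly $10^{-8}$ would point to a mismatched derivative.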
package machinelearning.ann;

/**
 * Activator.
 * 
 * @author Fan Min [email protected].
 */

public class Activator {
	/**
	 * Arc tan.
	 */
	public final char ARC_TAN = 'a';

	/**
	 * Elu.
	 */
	public final char ELU = 'e';

	/**
	 * Gelu.
	 */
	public final char GELU = 'g';

	/**
	 * Hard logistic.
	 */
	public final char HARD_LOGISTIC = 'h';

	/**
	 * Identity.
	 */
	public final char IDENTITY = 'i';

	/**
	 * Leaky relu. The slope for negative inputs is alpha.
	 */
	public final char LEAKY_RELU = 'l';

	/**
	 * Relu.
	 */
	public final char RELU = 'r';

	/**
	 * Soft sign.
	 */
	public final char SOFT_SIGN = 'o';

	/**
	 * Sigmoid.
	 */
	public final char SIGMOID = 's';

	/**
	 * Tanh.
	 */
	public final char TANH = 't';

	/**
	 * Soft plus.
	 */
	public final char SOFT_PLUS = 'u';

	/**
	 * Swish.
	 */
	public final char SWISH = 'w';

	/**
	 * The activator.
	 */
	private char activator;

	/**
	 * Alpha for elu and leaky relu.
	 */
	double alpha;

	/**
	 * Beta. Not used by the activators implemented so far.
	 */
	double beta;

	/**
	 * Gamma. Not used by the activators implemented so far.
	 */
	double gamma;

	/**
	 *********************
	 * The first constructor.
	 * 
	 * @param paraActivator
	 *            The activator.
	 *********************
	 */
	public Activator(char paraActivator) {
		activator = paraActivator;
	}// Of the first constructor

	/**
	 *********************
	 * Setter.
	 *********************
	 */
	public void setActivator(char paraActivator) {
		activator = paraActivator;
	}// Of setActivator

	/**
	 *********************
	 * Getter.
	 *********************
	 */
	public char getActivator() {
		return activator;
	}// Of getActivator

	/**
	 *********************
	 * Setter.
	 *********************
	 */
	void setAlpha(double paraAlpha) {
		alpha = paraAlpha;
	}// Of setAlpha

	/**
	 *********************
	 * Setter.
	 *********************
	 */
	void setBeta(double paraBeta) {
		beta = paraBeta;
	}// Of setBeta

	/**
	 *********************
	 * Setter.
	 *********************
	 */
	void setGamma(double paraGamma) {
		gamma = paraGamma;
	}// Of setGamma

	/**
	 *********************
	 * Activate according to the activation function.
	 *********************
	 */
	public double activate(double paraValue) {
		double resultValue = 0;
		switch (activator) {
		case ARC_TAN:
			resultValue = Math.atan(paraValue);
			break;
		case ELU:
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = alpha * (Math.exp(paraValue) - 1);
			} // Of if
			break;
		// case GELU:
		// resultValue = ?;
		// break;
		// case HARD_LOGISTIC:
		// resultValue = ?;
		// break;
		case IDENTITY:
			resultValue = paraValue;
			break;
		case LEAKY_RELU:
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = alpha * paraValue;
			} // Of if
			break;
		case SOFT_SIGN:
			if (paraValue >= 0) {
				resultValue = paraValue / (1 + paraValue);
			} else {
				resultValue = paraValue / (1 - paraValue);
			} // Of if
			break;
		case SOFT_PLUS:
			resultValue = Math.log(1 + Math.exp(paraValue));
			break;
		case RELU:
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = 0;
			} // Of if
			break;
		case SIGMOID:
			resultValue = 1 / (1 + Math.exp(-paraValue));
			break;
		case TANH:
			resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
			break;
		// case SWISH:
		// resultValue = ?;
		// break;
		default:
			System.out.println("Unsupported activator: " + activator);
			System.exit(0);
		}// Of switch

		return resultValue;
	}// Of activate

	/**
	 *********************
	 * Derive according to the activation function. Some use x while others use
	 * f(x).
	 * 
	 * @param paraValue
	 *            The original value x.
	 * @param paraActivatedValue
	 *            f(x).
	 *********************
	 */
	public double derive(double paraValue, double paraActivatedValue) {
		double resultValue = 0;
		switch (activator) {
		case ARC_TAN:
			resultValue = 1 / (paraValue * paraValue + 1);
			break;
		case ELU:
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
			} // Of if
			break;
		// case GELU:
		// resultValue = ?;
		// break;
		// case HARD_LOGISTIC:
		// resultValue = ?;
		// break;
		case IDENTITY:
			resultValue = 1;
			break;
		case LEAKY_RELU:
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = alpha;
			} // Of if
			break;
		case SOFT_SIGN:
			if (paraValue >= 0) {
				resultValue = 1 / (1 + paraValue) / (1 + paraValue);
			} else {
				resultValue = 1 / (1 - paraValue) / (1 - paraValue);
			} // Of if
			break;
		case SOFT_PLUS:
			resultValue = 1 / (1 + Math.exp(-paraValue));
			break;
		case RELU: // Updated
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = 0;
			} // Of if
			break;
		case SIGMOID: // Updated
			resultValue = paraActivatedValue * (1 - paraActivatedValue);
			break;
		case TANH: // Updated
			resultValue = 1 - paraActivatedValue * paraActivatedValue;
			break;
		// case SWISH:
		// resultValue = ?;
		// break;
		default:
			System.out.println("Unsupported activator: " + activator);
			System.exit(0);
		}// Of switch

		return resultValue;
	}// Of derive

	/**
	 *********************
	 * Overrides the method declared in Object.
	 *********************
	 */
	public String toString() {
		String resultString = "Activator with function '" + activator + "'";
		resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;

		return resultString;
	}// Of toString

	/**
	 ********************
	 * Test the class.
	 ********************
	 */
	public static void main(String[] args) {
		Activator tempActivator = new Activator('s');
		double tempValue = 0.6;
		double tempNewValue;
		tempNewValue = tempActivator.activate(tempValue);
		System.out.println("After activation: " + tempNewValue);

		tempNewValue = tempActivator.derive(tempValue, tempNewValue);
		System.out.println("After derive: " + tempNewValue);
	}// Of main
}// Of class Activator
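Item 3 of Day 74 asks for the unimplemented activators. The sketch below gives one possible set of formulas as standalone static methods; it is not wired into Activator. Assumptions: hard logistic uses the common definition $\max(0, \min(1, 0.25x + 0.5))$ (some libraries use slope 0.2 instead); Swish is $x\,\sigma(\beta x)$ with a caller-supplied $\beta$; GELU uses the widely used tanh approximation of $x\,\Phi(x)$, since java.lang.Math has no erf.

```java
/** Illustrative sketch: formulas for the activators left as "?" in Activator. */
class MoreActivationsSketch {

    /** Hard logistic, one common definition: clip(0.25 x + 0.5, 0, 1). */
    static double hardLogistic(double x) {
        return Math.max(0, Math.min(1, 0.25 * x + 0.5));
    }

    /** Its derivative: 0.25 on (-2, 2), 0 elsewhere (the kinks at -2 and 2 get 0 here). */
    static double hardLogisticDerivative(double x) {
        return (x > -2 && x < 2) ? 0.25 : 0;
    }

    static double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /** Swish: x * sigmoid(beta * x). */
    static double swish(double x, double beta) {
        return x * sigmoid(beta * x);
    }

    /** d/dx [x * s(bx)] = s(bx) + b * x * s(bx) * (1 - s(bx)). */
    static double swishDerivative(double x, double beta) {
        double s = sigmoid(beta * x);
        return s + beta * x * s * (1 - s);
    }

    private static final double C = Math.sqrt(2 / Math.PI);

    /** GELU, tanh approximation: 0.5 x (1 + tanh(C (x + 0.044715 x^3))). */
    static double gelu(double x) {
        return 0.5 * x * (1 + Math.tanh(C * (x + 0.044715 * x * x * x)));
    }

    /** Derivative of the tanh approximation above. */
    static double geluDerivative(double x) {
        double t = Math.tanh(C * (x + 0.044715 * x * x * x));
        return 0.5 * (1 + t) + 0.5 * x * (1 - t * t) * C * (1 + 3 * 0.044715 * x * x);
    }

    public static void main(String[] args) {
        for (double x : new double[] {-3, -0.5, 0, 1.5}) {
            System.out.printf("x=%5.2f  hard=%.4f  swish=%.4f  gelu=%.4f%n",
                    x, hardLogistic(x), swish(x, 1), gelu(x));
        }
    }
}
```

To plug these into Activator, each formula would become a case in activate() and its derivative a case in derive(); Swish's $\beta$ could, for example, reuse the currently unused beta field.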

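Day 75: General BP neural network (2. Single-layer implementation)

  1. Implement a single ANN layer only.
  2. Each layer can have its own activation function.
  3. Forward computes the output; backward computes the errors and adjusts the weights.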
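In formulas, the backward pass of one layer in AnnLayer.backPropagation amounts to the following. The notation is ours, not the original author's: $x_i$ ($i = 0, \dots, n-1$) are the layer's inputs, $w_{ij}$ is the weight from input $i$ to output $j$, row $n$ holds the offset (with $x_n = 1$), $z_j$ is the weighted sum (the `output` array), $f$ is the activator, $e_j$ is the incoming error, $\eta$ is the learning rate and $\mu$ is mobp.

$$
\begin{aligned}
z_j &= w_{nj} + \sum_{i=0}^{n-1} x_i w_{ij}, \qquad a_j = f(z_j),\\
\delta_j &= f'(z_j)\, e_j \qquad (\text{at the last layer } e_j = t_j - a_j),\\
e^{\text{prev}}_i &= \sum_{j} \delta_j\, w_{ij}, \qquad i = 0, \dots, n-1,\\
\Delta w_{ij} &\leftarrow \mu\, \Delta w_{ij} + \eta\, \delta_j\, x_i, \qquad w_{ij} \leftarrow w_{ij} + \Delta w_{ij}, \qquad i = 0, \dots, n.
\end{aligned}
$$

$e^{\text{prev}}$ is computed with the weights before they are updated and becomes the incoming error of the previous layer. With $e_j = t_j - a_j$ at the output, adding $\eta\,\delta_j x_i$ is a gradient-descent step on the squared error $\frac{1}{2}\sum_j (t_j - a_j)^2$.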
package machinelearning.ann;

import java.util.Arrays;
import java.util.Random;

/**
 * Ann layer.
 * 
 * @author Fan Min [email protected].
 */
public class AnnLayer {

	/**
	 * The number of inputs.
	 */
	int numInput;

	/**
	 * The number of outputs.
	 */
	int numOutput;

	/**
	 * The learning rate.
	 */
	double learningRate;

	/**
	 * The momentum coefficient.
	 */
	double mobp;

	/**
	 * The weight matrix.
	 */
	double[][] weights;

	/**
	 * The delta weight matrix.
	 */
	double[][] deltaWeights;

	/**
	 * Error on nodes.
	 */
	double[] errors;

	/**
	 * The inputs.
	 */
	double[] input;

	/**
	 * The outputs.
	 */
	double[] output;

	/**
	 * The output after activation.
	 */
	double[] activatedOutput;

	/**
	 * The activator.
	 */
	Activator activator;

	/**
	 * For random number generation.
	 */
	Random random = new Random();

	/**
	 *********************
	 * The first constructor.
	 * 
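	 * @param paraNumInput
	 *            The number of inputs (nodes of the previous layer).
	 * @param paraNumOutput
	 *            The number of outputs (nodes of this layer).
	 * @param paraActivator
	 *            The activator.
	 * @param paraLearningRate
	 *            The learning rate.
	 * @param paraMobp
	 *            The momentum coefficient.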
	 *********************
	 */
	public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator,
			double paraLearningRate, double paraMobp) {
		numInput = paraNumInput;
		numOutput = paraNumOutput;
		learningRate = paraLearningRate;
		mobp = paraMobp;

		weights = new double[numInput + 1][numOutput];
		deltaWeights = new double[numInput + 1][numOutput];
		for (int i = 0; i < numInput + 1; i++) {
			for (int j = 0; j < numOutput; j++) {
				weights[i][j] = random.nextDouble();
			} // Of for j
		} // Of for i

		errors = new double[numInput];

		input = new double[numInput];
		output = new double[numOutput];
		activatedOutput = new double[numOutput];

		activator = new Activator(paraActivator);
	}// Of the first constructor

	/**
	 ********************
	 * Set parameters for the activator.
	 * 
	 * @param paraAlpha
	 *            Alpha. Only valid for certain types.
	 * @param paraBeta
	 *            Beta.
	 * @param paraGamma
	 *            Gamma.
	 ********************
	 */
	public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
		activator.setAlpha(paraAlpha);
		activator.setBeta(paraBeta);
		activator.setGamma(paraGamma);
	}// Of setParameters

	/**
	 ********************
	 * Forward prediction.
	 * 
	 * @param paraInput
	 *            The input data of one instance.
	 * @return The data at the output end.
	 ********************
	 */
	public double[] forward(double[] paraInput) {
		//System.out.println("Ann layer forward " + Arrays.toString(paraInput));
		// Copy data.
		for (int i = 0; i < numInput; i++) {
			input[i] = paraInput[i];
		} // Of for i

		// Calculate the weighted sum for each output.
		for (int i = 0; i < numOutput; i++) {
			output[i] = weights[numInput][i];
			for (int j = 0; j < numInput; j++) {
				output[i] += input[j] * weights[j][i];
			} // Of for j

			activatedOutput[i] = activator.activate(output[i]);
		} // Of for i

		return activatedOutput;
	}// Of forward

	/**
	 ********************
	 * Back propagation and change the edge weights.
	 * 
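	 * @param paraErrors
	 *            The errors from the next layer, or from getLastLayerErrors() for
	 *            the last layer. Note that this array is modified in place.
	 * @return The errors for the previous layer.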
	 ********************
	 */
	public double[] backPropagation(double[] paraErrors) {
		// Step 1. Multiply the errors by the derivative of the activation.
		for (int i = 0; i < paraErrors.length; i++) {
			paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
		}//Of for i
		
		// Step 2. Compute the errors for the previous layer (using the weights
		// before the update), then adjust the weights.
		for (int i = 0; i < numInput; i++) {
			errors[i] = 0;
			for (int j = 0; j < numOutput; j++) {
				errors[i] += paraErrors[j] * weights[i][j];
				deltaWeights[i][j] = mobp * deltaWeights[i][j]
						+ learningRate * paraErrors[j] * input[i];
				weights[i][j] += deltaWeights[i][j];
			} // Of for j
		} // Of for i

		// Step 3. Adjust the offset weights (the last row), whose input is always +1.
		for (int j = 0; j < numOutput; j++) {
			deltaWeights[numInput][j] = mobp * deltaWeights[numInput][j] + learningRate * paraErrors[j];
			weights[numInput][j] += deltaWeights[numInput][j];
		} // Of for j

		return errors;
	}// Of backPropagation

	/**
	 ********************
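	 * Compute the errors (target minus activated output) at the output end.
	 * Only meaningful for the last layer.
	 * 
	 * @param paraTarget
	 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 * @return The errors of the output layer.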
	 ********************
	 */
	public double[] getLastLayerErrors(double[] paraTarget) {
		double[] resultErrors = new double[numOutput];
		for (int i = 0; i < numOutput; i++) {
			resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
		} // Of for i

		return resultErrors;
	}// Of getLastLayerErrors

	/**
	 ********************
	 * Show me.
	 ********************
	 */
	public String toString() {
		String resultString = "";
		resultString += "Activator: " + activator;
		resultString += "\r\n weights = " + Arrays.deepToString(weights);
		return resultString;
	}// Of toString

	/**
	 ********************
	 * Unit test.
	 ********************
	 */
	public static void unitTest() {
		AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
		double[] tempInput = { 1, 4 };

		System.out.println(tempLayer);

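		double[] tempOutput = tempLayer.forward(tempInput);
		System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));

		// Errors against the target [0, 1, 0]. getLastLayerErrors returns a new array,
		// so back propagation does not overwrite the layer's activated output.
		double[] tempLastErrors = tempLayer.getLastLayerErrors(new double[] { 0, 1, 0 });
		double[] tempError = tempLayer.backPropagation(tempLastErrors);
		System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));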
	}// Of unitTest

	/**
	 ********************
	 * Test the algorithm.
	 ********************
	 */
	public static void main(String[] args) {
		unitTest();
	}// Of main
}// Of class AnnLayer

Day 76: General BP neural network (3. Integration test)

  1. Try other activation functions yourself.
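When trying other activators, note that FullAnn reads one character per edge layer through paraActivators.charAt(i), i.e., numLayers - 1 characters: {4, 8, 8, 3} needs three, such as "sss" or "tts". A shorter string makes charAt throw StringIndexOutOfBoundsException, and an unsupported code such as 'g' only fails at the first forward call. The hypothetical helper below (not in the original) fails early with a clear message; it is deliberately stricter and also rejects strings that are too long. The supported codes are the cases actually handled in Activator.activate().

```java
/** Hypothetical helper (not in the original): validates FullAnn's activator string. */
class ActivatorStringSketch {

    /** Activator codes handled by Activator.activate(): a, e, i, l, o, r, s, t, u. */
    static final String SUPPORTED = "aeilorstu";

    static void validate(int[] layerNumNodes, String activators) {
        int expected = layerNumNodes.length - 1; // one activator per edge layer
        if (activators.length() != expected) {
            throw new IllegalArgumentException("Expected " + expected
                    + " activator characters, got \"" + activators + "\"");
        }
        for (char c : activators.toCharArray()) {
            if (SUPPORTED.indexOf(c) < 0) {
                throw new IllegalArgumentException("Unsupported activator: " + c);
            }
        }
    }

    public static void main(String[] args) {
        int[] layerNumNodes = {4, 8, 8, 3};
        validate(layerNumNodes, "sss"); // passes
        validate(layerNumNodes, "tts"); // passes: tanh, tanh, sigmoid
        try {
            validate(layerNumNodes, "ss");
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```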
package machinelearning.ann;

import java.util.Arrays;

/**
 * Full ANN with a number of layers.
 * 
 * @author Fan Min [email protected].
 */
public class FullAnn extends GeneralAnn {

	/**
	 * The layers.
	 */
	AnnLayer[] layers;

	/**
	 ********************
	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The arff filename.
	 * @param paraLayerNumNodes
	 *            The number of nodes for each layer (may be different).
	 * @param paraLearningRate
	 *            Learning rate.
	 * @param paraMobp
	 *            Momentum coefficient.
	 * @param paraActivators
	 *            The activators of the layers, one character per edge layer,
	 *            e.g., "sss" for {4, 8, 8, 3}.
	 ********************
	 */
	public FullAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
			double paraMobp, String paraActivators) {
		super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);

		// Initialize layers.
		layers = new AnnLayer[numLayers - 1];
		for (int i = 0; i < layers.length; i++) {
			layers[i] = new AnnLayer(layerNumNodes[i], layerNumNodes[i + 1], paraActivators.charAt(i), paraLearningRate,
					paraMobp);
		} // Of for i
	}// Of the first constructor

	/**
	 ********************
	 * Forward prediction through all layers, from the first to the last.
	 * 
	 * @param paraInput
	 *            The input data of one instance.
	 * @return The data at the output end.
	 ********************
	 */
	public double[] forward(double[] paraInput) {
		double[] resultArray = paraInput;
		for(int i = 0; i < numLayers - 1; i ++) {
			resultArray = layers[i].forward(resultArray);
		}//Of for i
		
		return resultArray;
	}// Of forward

	/**
	 ********************
	 * Back propagation through all layers, from the last to the first.
	 * 
	 * @param paraTarget
	 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 *            
	 ********************
	 */
	public void backPropagation(double[] paraTarget) {
		double[] tempErrors = layers[numLayers - 2].getLastLayerErrors(paraTarget);
		for (int i = numLayers - 2; i >= 0; i--) {
			tempErrors = layers[i].backPropagation(tempErrors);
		}//Of for i
		
		return;
	}// Of backPropagation

	/**
	 ********************
	 * Show me.
	 ********************
	 */
	public String toString() {
		String resultString = "I am a full ANN with " + numLayers + " layers";
		return resultString;
	}// Of toString

	/**
	 ********************
	 * Test the algorithm.
	 ********************
	 */
	public static void main(String[] args) {
		int[] tempLayerNodes = { 4, 8, 8, 3 };
		FullAnn tempNetwork = new FullAnn("D:/data/iris.arff", tempLayerNodes, 0.01,
				0.6, "sss");

		for (int round = 0; round < 5000; round++) {
			tempNetwork.train();
		} // Of for round

		double tempAccuracy = tempNetwork.test();
		System.out.println("The accuracy is: " + tempAccuracy);
		System.out.println("FullAnn ends.");
	}// Of main	
}// Of class FullAnn

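Day 77: GUI (1. Dialog-related components)

GUI code may feel tedious when you first meet it; if you cannot stand it, you may simply copy the programs (the same holds for the following days). Still, different GUIs are much alike, and most of the code can be reused. In fact, today's and tomorrow's code is more than 20 years old (yes, you read that right: I wrote it as a student); I only tidied it up again for this post.
The GUI code also helps to understand Java's object-oriented mechanisms further. For these mechanisms (interfaces, listeners, exceptions), see books such as Thinking in Java, or work them out yourself (when I learned Java in 1998, I read no book and consulted the JDK docs instead).

Notes on the code:

  1. ApplicationShutdown.java is used only to exit the graphical user interface (GUI).
  2. Only one static instance is created. The constructor is private, so no code outside the class can create an instance with new. This is a neat little trick.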
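The private-constructor trick in item 2 can be shown in a few lines. In the illustrative sketch below (the class name is made up), the only instance is built during class initialization and a counter confirms the constructor runs exactly once. One small variation from ApplicationShutdown: the field is also declared final, which additionally prevents it from being reassigned.

```java
/** Illustrative sketch of the "only one instance" trick used by ApplicationShutdown. */
class SingletonSketch {

    /** Counts constructor calls. Declared before INSTANCE so it is initialized first. */
    private static int instancesCreated = 0;

    /** The only instance, created when the class is initialized. */
    public static final SingletonSketch INSTANCE = new SingletonSketch();

    /** Private: code outside this class cannot call new SingletonSketch(). */
    private SingletonSketch() {
        instancesCreated++;
    }

    static int getInstancesCreated() {
        return instancesCreated;
    }

    public static void main(String[] args) {
        SingletonSketch first = SingletonSketch.INSTANCE;
        SingletonSketch second = SingletonSketch.INSTANCE;
        System.out.println((first == second) + " " + getInstancesCreated()); // prints "true 1"
    }
}
```

Note the field order: if the counter were declared after INSTANCE with an explicit "= 0" initializer, it would be reset to 0 after the constructor had already incremented it.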
package machinelearning.gui;

import java.awt.event.*;

/**
 * Shut down the application according to window action or button action.
 * @author Fan Min [email protected].
 */
public class ApplicationShutdown implements WindowListener, ActionListener {
	/**
	 * Only one instance.
	 */
	public static ApplicationShutdown applicationShutdown = new ApplicationShutdown();

	/**
	 *************************** 
	 * This constructor is private such that the only instance is generated here.
	 *************************** 
	 */
	private ApplicationShutdown() {
	}// Of ApplicationShutdown.

	/**
	 *************************** 
	 * Shut down the system when the window is closing.
	 *************************** 
	 */
	public void windowClosing(WindowEvent comeInWindowEvent) {
		System.exit(0);
	}// Of windowClosing.

	public void windowActivated(WindowEvent comeInWindowEvent) {
	}// Of windowActivated.

	public void windowClosed(WindowEvent comeInWindowEvent) {
	}// Of windowClosed.

	public void windowDeactivated(WindowEvent comeInWindowEvent) {
	}// Of windowDeactivated.

	public void windowDeiconified(WindowEvent comeInWindowEvent) {
	}// Of windowDeiconified.

	public void windowIconified(WindowEvent comeInWindowEvent) {
	}// Of windowIconified.

	public void windowOpened(WindowEvent comeInWindowEvent) {
	}// Of windowOpened.

	/**
	 *************************** 
	 * Shut down the system when the button this listener is bound to is clicked.
	 *************************** 
	 */
	public void actionPerformed(ActionEvent ee) {
		System.exit(0);
	}// Of actionPerformed.
}// Of class ApplicationShutdown

DialogCloser.java closes a dialog window rather than the whole GUI.

package machinelearning.gui;

import java.awt.*;
import java.awt.event.*;

/**
 * Close the dialog.
 * @author Fan Min [email protected].
 */
public class DialogCloser extends WindowAdapter implements ActionListener {

	/**
	 * The dialog under control.
	 */
	private Dialog currentDialog;

	/**
	 *************************** 
	 * The first constructor.
	 *************************** 
	 */
	public DialogCloser() {
		super();
	}// Of the first constructor

	/**
	 *************************** 
	 * The second constructor.
	 * 
	 * @param paraDialog
	 *            the dialog under control
	 *************************** 
	 */
	public DialogCloser(Dialog paraDialog) {
		currentDialog = paraDialog;
	}// Of the second constructor

	/**
	 *************************** 
	 * Close the dialog when the user clicks the cross at the upper-right corner of the window.
	 * 
	 * @param paraWindowEvent
	 *            From it we can obtain which window sent the message because X
	 *            was used.
	 *************************** 
	 */
	public void windowClosing(WindowEvent paraWindowEvent) {
		paraWindowEvent.getWindow().dispose();
	}// Of windowClosing.

	/**
	 *************************** 
	 * Close the dialog while pushing an "OK" or "Cancel" button.
	 * 
	 * @param paraEvent
	 *            Not considered. 
	 *************************** 
	 */
	public void actionPerformed(ActionEvent paraEvent) {
		currentDialog.dispose();
	}// Of actionPerformed.
}// Of class DialogCloser

ErrorDialog.java displays error messages. With a GUI, we no longer need System.out.println.

package machinelearning.gui;

import java.awt.*;

/**
 * For error message.
 * @author Fan Min [email protected].
 */
public class ErrorDialog extends Dialog {

	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = 124535235L;

	/**
	 * The ONLY ErrorDialog.
	 */
	public static ErrorDialog errorDialog = new ErrorDialog();

	/**
	 * The text area containing the message to display.
	 */
	private TextArea messageTextArea;

	/**
	 *************************** 
	 * Display an error dialog and respective error message. Like other dialogs,
	 * this constructor is private, such that users can use only one dialog,
	 * i.e., ErrorDialog.errorDialog to display message. This is helpful for
	 * saving space (only one dialog) since we may need "many" dialogs.
	 *************************** 
	 */
	private ErrorDialog() {
		// This dialog is modal.
		super(GUICommon.mainFrame, "Error", true);

		// Prepare for the dialog.
		messageTextArea = new TextArea();

		Button okButton = new Button("OK");
		okButton.setSize(20, 10);
		okButton.addActionListener(new DialogCloser(this));
		Panel okPanel = new Panel();
		okPanel.setLayout(new FlowLayout());
		okPanel.add(okButton);

		// Add TextArea and Button
		setLayout(new BorderLayout());
		add(BorderLayout.CENTER, messageTextArea);
		add(BorderLayout.SOUTH, okPanel);

		setLocation(200, 200);
		setSize(500, 200);
		addWindowListener(new DialogCloser());
		setVisible(false);
	}// Of constructor

	/**
	 *************************** 
	 * Set the message and show the dialog.
	 * 
	 * @param paramMessage
	 *            The new message.
	 *************************** 
	 */
	public void setMessageAndShow(String paramMessage) {
		messageTextArea.setText(paramMessage);
		setVisible(true);
	}// Of setMessageAndShow
}// Of class ErrorDialog
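ErrorDialog is a singleton: its constructor is private and the only instance is created eagerly in a static field. The sketch below shows the same pattern without AWT so that it can run headless; the class name MessageBoard and its methods are hypothetical stand-ins for ErrorDialog and setMessageAndShow. Unlike ErrorDialog.errorDialog, the field is declared final here, which is the usual idiom because it stops other code from replacing the shared instance.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical headless stand-in for ErrorDialog: a singleton created eagerly.
 */
final class MessageBoard {
	/**
	 * The ONLY instance. Declared final so that no one can replace it.
	 */
	public static final MessageBoard INSTANCE = new MessageBoard();

	/**
	 * Messages "shown" so far; plays the role of the dialog's text area.
	 */
	private final List<String> shownMessages = new ArrayList<>();

	/**
	 * Private, so other classes cannot create a second board.
	 */
	private MessageBoard() {
	}// Of constructor

	/**
	 * Analogue of setMessageAndShow: record the message and print it.
	 */
	public void setMessageAndShow(String paraMessage) {
		shownMessages.add(paraMessage);
		System.out.println("[Error] " + paraMessage);
	}// Of setMessageAndShow

	/**
	 * How many messages have been shown so far.
	 */
	public int getNumShown() {
		return shownMessages.size();
	}// Of getNumShown

	public static void main(String[] args) {
		// Every caller reaches the same object through MessageBoard.INSTANCE.
		MessageBoard.INSTANCE.setMessageAndShow("\"abc\" Not a double. Please check.");
		System.out.println(MessageBoard.INSTANCE.getNumShown());
	}// Of main
}// Of class MessageBoard
```

Writing `new MessageBoard()` anywhere outside the class does not compile, which is exactly the guarantee ErrorDialog relies on.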

GUICommon.java stores some shared variables.

package machinelearning.gui;

import java.awt.*;
import javax.swing.*;

/**
 * Manage the GUI.
 * 
 * @author Fan Min [email protected].
 */
public class GUICommon extends Object {
	/**
	 * Only one main frame.
	 */
	public static Frame mainFrame = null;

	/**
	 * Only one main pane.
	 */
	public static JTabbedPane mainPane = null;

	/**
	 * For default project number.
	 */
	public static int currentProjectNumber = 0;

	/**
	 * Default font.
	 */
	public static final Font MY_FONT = new Font("Times New Roman", Font.PLAIN, 12);

	/**
	 * Default color
	 */
	public static final Color MY_COLOR = Color.lightGray;

	/**
	 *************************** 
	 * Set the main frame. This can be done only once at the initializing stage.
	 * 
	 * @param paraFrame
	 *            the main frame of the GUI.
	 * @throws Exception
	 *             If the main frame is set more than once.
	 *************************** 
	 */
	public static void setFrame(Frame paraFrame) throws Exception {
		if (mainFrame == null) {
			mainFrame = paraFrame;
		} else {
			throw new Exception("The main frame can be set only ONCE!");
		} // Of if
	}// Of setFrame

	/**
	 *************************** 
	 * Set the main pane. This can be done only once at the initializing stage.
	 * 
	 * @param paramPane
	 *            the main pane of the GUI.
	 * @throws Exception
	 *             If the main pane is set more than once.
	 *************************** 
	 */
	public static void setPane(JTabbedPane paramPane) throws Exception {
		if (mainPane == null) {
			mainPane = paramPane;
		} else {
			throw new Exception("The main pane can be set only ONCE!");
		} // Of if
	}// Of setPane

}// Of class GUICommon
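GUICommon enforces "set only once" by hand in both setFrame and setPane. A minimal sketch of the same rule as a reusable generic holder follows; SetOnce is a hypothetical name, and two choices differ from the original: null is rejected (in GUICommon, passing null leaves the field unset, so a later call would still succeed), and unchecked exceptions are used instead of a plain Exception.

```java
/**
 * Hypothetical generic version of the "set only ONCE" rule in GUICommon.
 */
class SetOnce<T> {
	/**
	 * The stored value; null means "not set yet".
	 */
	private T value = null;

	/**
	 * Store the value. Refuse null and refuse a second assignment.
	 */
	public synchronized void set(T paraValue) {
		if (paraValue == null) {
			throw new IllegalArgumentException("Null is not allowed.");
		} // Of if
		if (value != null) {
			throw new IllegalStateException("The value can be set only ONCE!");
		} // Of if
		value = paraValue;
	}// Of set

	/**
	 * Get the value, or null if it has not been set.
	 */
	public synchronized T get() {
		return value;
	}// Of get

	public static void main(String[] args) {
		SetOnce<String> tempFrameName = new SetOnce<>();
		tempFrameName.set("main frame");
		try {
			tempFrameName.set("another frame");
		} catch (IllegalStateException ee) {
			System.out.println(ee.getMessage());
		} // Of try
		System.out.println(tempFrameName.get());
	}// Of main
}// Of class SetOnce
```

The methods are synchronized only to make the rule safe if it is ever called from several threads; for GUI setup on a single thread, the plain checks in GUICommon behave the same.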

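HelpDialog.java displays help information: when the Help button on the main window is clicked, it shows descriptions of the relevant parameters. The aim is to make the software easier to use and to maintain.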

package machinelearning.gui;

import java.io.*;
import java.awt.*;
import java.awt.event.*;

/**
 * Display the help message.
 * 
 * @author Fan Min [email protected].
 */
public class HelpDialog extends Dialog implements ActionListener {
	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = 3869415040299264995L;

	/**
	 *************************** 
	 * Display the help dialog.
	 * 
	 * @param paraTitle
	 *            the title of the dialog.
	 * @param paraFilename
	 *            the help file.
	 *************************** 
	 */
	public HelpDialog(String paraTitle, String paraFilename) {
		super(GUICommon.mainFrame, paraTitle, true);
		setBackground(GUICommon.MY_COLOR);

		TextArea displayArea = new TextArea("", 10, 10, TextArea.SCROLLBARS_VERTICAL_ONLY);
		displayArea.setEditable(false);
		String textToDisplay = "";
		try {
			RandomAccessFile helpFile = new RandomAccessFile(paraFilename, "r");
			String tempLine = helpFile.readLine();
			while (tempLine != null) {
				textToDisplay = textToDisplay + tempLine + "\n";
				tempLine = helpFile.readLine();
			}
			helpFile.close();
		} catch (IOException ee) {
			dispose();
			ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
		}
		// Use this if you need to display Chinese. Consult the author for this
		// method.
		// textToDisplay = SimpleTools.GB2312ToUNICODE(textToDisplay);
		displayArea.setText(textToDisplay);
		displayArea.setFont(new Font("Times New Roman", Font.PLAIN, 14));

		Button okButton = new Button("OK");
		okButton.setSize(20, 10);
		okButton.addActionListener(new DialogCloser(this));
		Panel okPanel = new Panel();
		okPanel.setLayout(new FlowLayout());
		okPanel.add(okButton);

		// Add the text area and the OK button.
		setLayout(new BorderLayout());
		add(BorderLayout.CENTER, displayArea);
		add(BorderLayout.SOUTH, okPanel);

		setLocation(120, 70);
		setSize(500, 400);
		addWindowListener(new DialogCloser());
		setVisible(false);
	}// Of constructor

	/**
	 ************************* 
	 * Simply set it visible.
	 ************************* 
	 */
	public void actionPerformed(ActionEvent ee) {
		setVisible(true);
	}// Of actionPerformed.
}// Of class HelpDialog
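HelpDialog reads help.txt with RandomAccessFile.readLine, which turns each byte into one char and therefore does not decode multi-byte characters; this is likely why the comment about displaying Chinese points to an extra conversion step. Assuming help.txt is saved as UTF-8, a sketch using java.nio.file.Files with an explicit charset avoids the problem. readHelpText is a hypothetical helper that keeps the original output format (each line followed by "\n").

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical helper: read the help file as UTF-8 text.
 */
class HelpTextReader {
	/**
	 * Read the whole file, ASSUMING it is saved as UTF-8. As in HelpDialog,
	 * every line (including the last) is followed by "\n".
	 */
	public static String readHelpText(String paraFilename) throws IOException {
		List<String> tempLines = Files.readAllLines(Paths.get(paraFilename),
				StandardCharsets.UTF_8);
		StringBuilder tempBuilder = new StringBuilder();
		for (String tempLine : tempLines) {
			tempBuilder.append(tempLine).append('\n');
		} // Of for tempLine
		return tempBuilder.toString();
	}// Of readHelpText

	public static void main(String[] args) throws IOException {
		// Write a small UTF-8 file that contains Chinese characters.
		Path tempFile = Files.createTempFile("help", ".txt");
		Files.write(tempFile, Arrays.asList("This is the ANN project GUI.",
				"6. \u5b66\u4e60\u7387"), StandardCharsets.UTF_8);
		System.out.print(readHelpText(tempFile.toString()));
		Files.delete(tempFile);
	}// Of main
}// Of class HelpTextReader
```

If the file is not valid UTF-8, readAllLines throws a MalformedInputException (an IOException) instead of returning garbled text, so the existing catch block in HelpDialog would still report it.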

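help.txt holds the content that HelpDialog displays; of course, you have to write it yourself. Put this file in the same folder as these Java files. AnnMain opens it through the relative path src/machinelearning/gui/help.txt, so run the program from the project root.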

This is the ANN project GUI.
1. The arff filename is for the data file.
2. Alpha, beta, and gamma are parameters for the activator.
3. Layer nodes should be some numbers separated by commas. For example, "4, 8, 8, 3" means that the number of input nodes is 4, the number of output nodes is 3, and both hidden layers have 8 nodes.
4. Activators are specified by a string, each character for the activator of a layer.
5. Training rounds: more rounds for more stable results.
6. Learning rate: the step size of gradient descent.
7. Mobp: the momentum coefficient.
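The help text describes two settings that must agree with each other. Judging from the GUI defaults ("4, 8, 8, 3" with "sss"), the activator string seems to have one character per non-input layer, i.e., one fewer than the number of layer sizes. This is an assumption drawn from the defaults, not something confirmed by FullAnn's code in this section. Under that assumption, a hypothetical pre-check could catch a mismatch before training starts:

```java
/**
 * Hypothetical pre-check of the GUI settings before training.
 */
class AnnSettingsChecker {
	/**
	 * Check that the layer sizes and the activator string agree, ASSUMING one
	 * activator character per non-input layer (as in "4, 8, 8, 3" with "sss").
	 * 
	 * @param paraLayerNodes
	 *            The number of nodes for each layer.
	 * @param paraActivators
	 *            One character per activator.
	 * @return null if the settings look consistent, otherwise an error message.
	 */
	public static String checkSettings(int[] paraLayerNodes, String paraActivators) {
		if (paraLayerNodes.length < 2) {
			return "At least an input layer and an output layer are required.";
		} // Of if
		for (int i = 0; i < paraLayerNodes.length; i++) {
			if (paraLayerNodes[i] <= 0) {
				return "Layer " + i + " must have at least one node.";
			} // Of if
		} // Of for i
		if (paraActivators.length() != paraLayerNodes.length - 1) {
			return "Expected " + (paraLayerNodes.length - 1) + " activators, but got "
					+ paraActivators.length() + ".";
		} // Of if
		return null;
	}// Of checkSettings

	public static void main(String[] args) {
		System.out.println(checkSettings(new int[] { 4, 8, 8, 3 }, "sss"));
		System.out.println(checkSettings(new int[] { 4, 8, 8, 3 }, "ss"));
	}// Of main
}// Of class AnnSettingsChecker
```

A non-null result could be passed straight to ErrorDialog.errorDialog.setMessageAndShow, in the same way AnnMain reports a malformed layer string.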

Day 78: GUI (2. Data input controls)

DoubleField.java accepts a real (double) value and reports an error if the input cannot be interpreted as one. This nips the user's simple mistakes in the bud.

package machinelearning.gui;

import java.awt.*;
import java.awt.event.*;

/**
 * For the input of a double value.
 * 
 * @author Fan Min [email protected].
 */
public class DoubleField extends TextField implements FocusListener {

	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = 363634723L;

	/**
	 * The value
	 */
	protected double doubleValue;

	/**
	 *************************** 
	 * Give it default values.
	 *************************** 
	 */
	public DoubleField() {
		this("5.13", 10);
	}// Of the first constructor

	/**
	 *************************** 
	 * Only specify the content.
	 * 
	 * @param paraString
	 *            The content of the field.
	 *************************** 
	 */
	public DoubleField(String paraString) {
		this(paraString, 10);
	}// Of the second constructor

	/**
	 *************************** 
	 * Only specify the width.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public DoubleField(int paraWidth) {
		this("5.13", paraWidth);
	}// Of the third constructor

	/**
	 *************************** 
	 * Specify the content and the width.
	 * 
	 * @param paraString
	 *            The content of the field.
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public DoubleField(String paraString, int paraWidth) {
		super(paraString, paraWidth);
		addFocusListener(this);
	}// Of the fourth constructor

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusGained(FocusEvent paraEvent) {
	}// Of focusGained

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusLost(FocusEvent paraEvent) {
		try {
			doubleValue = Double.parseDouble(getText());
		} catch (Exception ee) {
			ErrorDialog.errorDialog
					.setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
			requestFocus();
		} // Of try
	}// Of focusLost

	/**
	 ********************************** 
	 * Get the double value.
	 * 
	 * @return the double value.
	 ********************************** 
	 */
	public double getValue() {
		try {
			doubleValue = Double.parseDouble(getText());
		} catch (Exception ee) {
			ErrorDialog.errorDialog
					.setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
			requestFocus();
		} // Of try
		return doubleValue;
	}// Of getValue
}// Of class DoubleField
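DoubleField accepts anything Double.parseDouble accepts. Besides ignoring leading and trailing whitespace, that includes "NaN", "Infinity", and overflowing input such as "1e400" (which becomes Infinity); none of these makes sense as a learning rate or momentum. Below is a minimal sketch of a stricter check, kept separate from the GUI so it can be tested without a display; DoubleInput and tryParseFinite are hypothetical names.

```java
import java.util.OptionalDouble;

/**
 * Hypothetical GUI-free check for DoubleField input.
 */
class DoubleInput {
	/**
	 * Parse the text as a finite double. Non-finite results (NaN, Infinity)
	 * are rejected explicitly because Double.parseDouble accepts them.
	 * 
	 * @return the value, or an empty OptionalDouble if the text is unusable.
	 */
	public static OptionalDouble tryParseFinite(String paraText) {
		if (paraText == null) {
			return OptionalDouble.empty();
		} // Of if
		try {
			double tempValue = Double.parseDouble(paraText);
			return Double.isFinite(tempValue) ? OptionalDouble.of(tempValue)
					: OptionalDouble.empty();
		} catch (NumberFormatException ee) {
			return OptionalDouble.empty();
		} // Of try
	}// Of tryParseFinite

	public static void main(String[] args) {
		String[] tempInputs = { " 0.01 ", "abc", "NaN", "1e400" };
		for (String tempInput : tempInputs) {
			System.out.println("\"" + tempInput + "\" -> " + tryParseFinite(tempInput));
		} // Of for tempInput
	}// Of main
}// Of class DoubleInput
```

In focusLost or getValue, an empty result would trigger the same error dialog that a NumberFormatException triggers today.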

IntegerField.java works the same way for int values.

package machinelearning.gui;

import java.awt.*;
import java.awt.event.*;

/**
 * For the input of an int value.
 * @author Fan Min [email protected].
 */
public class IntegerField extends TextField implements FocusListener {

	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = -2462338973265150779L;

	/**
	 *************************** 
	 * Give it the default content "513".
	 *************************** 
	 */
	public IntegerField() {
		this("513");
	}// Of constructor

	/**
	 *************************** 
	 * Specify the content and the width.
	 * 
	 * @param paraString
	 *            The default value of the content.
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public IntegerField(String paraString, int paraWidth) {
		super(paraString, paraWidth);
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * Only specify the content.
	 * 
	 * @param paraString
	 *            The given default string.
	 *************************** 
	 */
	public IntegerField(String paraString) {
		super(paraString);
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * Only specify the width.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public IntegerField(int paraWidth) {
		super(paraWidth);
		setText("513");
		addFocusListener(this);
	}// Of constructor

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusGained(FocusEvent paraEvent) {
	}// Of focusGained

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusLost(FocusEvent paraEvent) {
		try {
			Integer.parseInt(getText());
		} catch (Exception ee) {
			ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
					+ "\" Not an integer. Please check.");
			requestFocus();
		}
	}// Of focusLost

	/**
	 ********************************** 
	 * Get the int value. Show error message if the content is not an int.
	 * 
	 * @return the int value.
	 ********************************** 
	 */
	public int getValue() {
		int tempInt = 0;
		try {
			tempInt = Integer.parseInt(getText());
		} catch (Exception ee) {
			ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
					+ "\" Not an int. Please check.");
			requestFocus();
		}
		return tempInt;
	}// Of getValue

}// Of class IntegerField
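Unlike Double.parseDouble, Integer.parseInt does not trim whitespace, so IntegerField rejects " 5000" while DoubleField accepts " 0.5". Values outside the int range, such as "5000000000", are rejected as well. A hypothetical helper that trims first and also enforces a lower bound (for example, at least one training round) might look like this:

```java
import java.util.OptionalInt;

/**
 * Hypothetical GUI-free check for IntegerField input.
 */
class IntInput {
	/**
	 * Parse an int after trimming surrounding whitespace.
	 * 
	 * @return the value if it is an int not less than paraMin, otherwise an
	 *         empty OptionalInt.
	 */
	public static OptionalInt tryParseAtLeast(String paraText, int paraMin) {
		if (paraText == null) {
			return OptionalInt.empty();
		} // Of if
		try {
			int tempValue = Integer.parseInt(paraText.trim());
			return tempValue >= paraMin ? OptionalInt.of(tempValue) : OptionalInt.empty();
		} catch (NumberFormatException ee) {
			// Not an integer, or outside the int range.
			return OptionalInt.empty();
		} // Of try
	}// Of tryParseAtLeast

	public static void main(String[] args) {
		String[] tempInputs = { " 5000 ", "0", "5000000000", "5.0" };
		for (String tempInput : tempInputs) {
			System.out.println("\"" + tempInput + "\" -> " + tryParseAtLeast(tempInput, 1));
		} // Of for tempInput
	}// Of main
}// Of class IntInput
```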

FilenameField.java, on the other hand, relies on the FileDialog provided by the system.

package machinelearning.gui;

import java.io.*;
import java.awt.*;
import java.awt.event.*;

/**
 * For the input of a filename.
 * @author Fan Min [email protected].
 */
public class FilenameField extends TextField implements ActionListener,
		FocusListener {
	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = 4572287941606065298L;

	/**
	 *************************** 
	 * No special initialization.
	 *************************** 
	 */
	public FilenameField() {
		super();
		setText("");
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * No special initialization.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public FilenameField(int paraWidth) {
		super(paraWidth);
		setText("");
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * No special initialization.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 * @param paraText
	 *            The given initial text
	 *************************** 
	 */
	public FilenameField(int paraWidth, String paraText) {
		super(paraWidth);
		setText(paraText);
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * No special initialization.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 * @param paraText
	 *            The given initial text
	 *************************** 
	 */
	public FilenameField(String paraText, int paraWidth) {
		super(paraWidth);
		setText(paraText);
		addFocusListener(this);
	}// Of constructor

	/**
	 ********************************** 
	 * Avoid setting null or empty string.
	 * 
	 * @param paraText
	 *            The given text.
	 ********************************** 
	 */
	public void setText(String paraText) {
		if (paraText == null || paraText.trim().equals("")) {
			super.setText("unspecified");
		} else {
			super.setText(paraText.replace('\\', '/'));
		} // Of if
	}// Of setText

	/**
	 ********************************** 
	 * Implement ActionListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void actionPerformed(ActionEvent paraEvent) {
		FileDialog tempDialog = new FileDialog(GUICommon.mainFrame,
				"Select a file");
		tempDialog.setVisible(true);
		if (tempDialog.getDirectory() == null) {
			setText("");
			return;
		}//Of if
		
		String directoryName = tempDialog.getDirectory();
		
		String tempFilename = directoryName + tempDialog.getFile(); 
		//System.out.println("tempFilename = " + tempFilename);

		setText(tempFilename);
	}// Of actionPerformed

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusGained(FocusEvent paraEvent) {
	}// Of focusGained

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusLost(FocusEvent paraEvent) {
		// System.out.println("Focus lost exists.");
		String tempString = getText();
		if ((tempString.equals("unspecified")) || (tempString.equals(""))) {
			return;
		} // Of if
		File tempFile = new File(tempString);
		if (!tempFile.exists()) {
			ErrorDialog.errorDialog.setMessageAndShow("File \"" + tempString
					+ "\" does not exist. Please check.");
			requestFocus();
			setText("");
		} // Of if
	}// Of focusLost
}// Of class FilenameField
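FilenameField's rules (a blank entry becomes "unspecified", backslashes become slashes, a real filename must exist) can be pulled out into plain functions so they can be tested without a GUI. The sketch below uses hypothetical names; it also handles null, and it uses File.isFile() rather than exists(), because exists() also returns true for a directory, which cannot be read as an arff file.

```java
import java.io.File;
import java.io.IOException;

/**
 * Hypothetical GUI-free version of FilenameField's rules.
 */
class FilenameCheck {
	/**
	 * The placeholder that FilenameField uses for "no file".
	 */
	public static final String UNSPECIFIED = "unspecified";

	/**
	 * Same rule as FilenameField.setText, plus a null check.
	 */
	public static String normalize(String paraText) {
		if (paraText == null || paraText.trim().isEmpty()) {
			return UNSPECIFIED;
		} // Of if
		return paraText.replace('\\', '/');
	}// Of normalize

	/**
	 * Like FilenameField.focusLost: the placeholder is accepted; otherwise the
	 * name must denote an existing regular file (isFile() instead of exists()).
	 */
	public static boolean isAcceptable(String paraText) {
		String tempText = normalize(paraText);
		if (tempText.equals(UNSPECIFIED)) {
			return true;
		} // Of if
		return new File(tempText).isFile();
	}// Of isAcceptable

	public static void main(String[] args) throws IOException {
		System.out.println(normalize("d:\\data\\iris.arff"));
		File tempFile = File.createTempFile("iris", ".arff");
		tempFile.deleteOnExit();
		System.out.println(isAcceptable(tempFile.getPath()));
		// A directory exists, but it is not an acceptable data file.
		System.out.println(isAcceptable(tempFile.getParent()));
	}// Of main
}// Of class FilenameCheck
```

Converting backslashes is harmless on Windows, because java.io.File also accepts '/' as a separator there.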

Day 79: GUI (3. Overall layout)

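Today we can finally finish copying the GUI. In fact, only today's code is specific to this project.

  1. GridLayout and BorderLayout are used to organize the controls.
  2. Pressing OK runs actionPerformed. Similar code appeared in the previous two days.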
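For GridLayout, when the number of rows is non-zero, the column count passed to the constructor is ignored and the columns are derived from the rows and the number of components. A small calculation (no AWT needed) of where the settingPanel in AnnMain places its labels, assuming the default left-to-right component orientation; the class and method names are hypothetical:

```java
/**
 * Hypothetical arithmetic behind settingPanel's GridLayout(3, 6).
 */
class GridLayoutArithmetic {
	/**
	 * Columns GridLayout actually uses when the row count is positive:
	 * ceil(numComponents / rows); the given column count is ignored.
	 */
	public static int effectiveColumns(int paraRows, int paraNumComponents) {
		return (paraNumComponents + paraRows - 1) / paraRows;
	}// Of effectiveColumns

	public static void main(String[] args) {
		String[] tempLabels = { "alpha", "beta", "gamma", "layer nodes", "activators",
				"training rounds", "learning rate", "mobp" };
		int tempRows = 3;
		// One label plus one field per setting.
		int tempNumComponents = 2 * tempLabels.length;
		int tempColumns = effectiveColumns(tempRows, tempNumComponents);
		System.out.println(tempRows + " x " + tempColumns + " grid");
		for (int i = 0; i < tempLabels.length; i++) {
			// Index of the label among the components, in the order they were added.
			int tempIndex = 2 * i;
			System.out.println(tempLabels[i] + ": row " + (tempIndex / tempColumns)
					+ ", column " + (tempIndex % tempColumns));
		} // Of for i
	}// Of main
}// Of class GridLayoutArithmetic
```

With 8 label/field pairs (16 components) and 3 rows, the grid is 3 x 6: each row holds three settings, and the last row holds only "learning rate" and "mobp", leaving two cells empty.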
package machinelearning.gui;

import java.awt.*;
import java.awt.event.*;
import java.util.Date;

import machinelearning.ann.FullAnn;

/**
 * The main entrance of ANN GUI.
 * 
 * @author Fan Min [email protected].
 */
public class AnnMain implements ActionListener {
	/**
	 * Select the arff file.
	 */
	private FilenameField arffFilenameField;

	/**
	 * The setting of alpha.
	 */
	private DoubleField alphaField;

	/**
	 * The setting of beta.
	 */
	private DoubleField betaField;

	/**
	 * The setting of gamma.
	 */
	private DoubleField gammaField;

	/**
	 * Layer nodes, such as "4, 8, 8, 3".
	 */
	private TextField layerNodesField;

	/**
	 * Activators, such as "ssa".
	 */
	private TextField activatorField;

	/**
	 * The number of training rounds.
	 */
	private IntegerField roundsField;

	/**
	 * The learning rate.
	 */
	private DoubleField learningRateField;

	/**
	 * The momentum coefficient (mobp).
	 */
	private DoubleField mobpField;

	/**
	 * The message area.
	 */
	private TextArea messageTextArea;

	/**
	 *************************** 
	 * The only constructor.
	 *************************** 
	 */
	public AnnMain() {
		// A simple frame to contain dialogs.
		Frame mainFrame = new Frame();
		mainFrame.setTitle("ANN. [email protected]");
		// The top part: select arff file.
		arffFilenameField = new FilenameField(30);
		arffFilenameField.setText("d:/data/iris.arff");
		Button browseButton = new Button(" Browse ");
		browseButton.addActionListener(arffFilenameField);

		Panel sourceFilePanel = new Panel();
		sourceFilePanel.add(new Label("The .arff file:"));
		sourceFilePanel.add(arffFilenameField);
		sourceFilePanel.add(browseButton);

		// Setting panel.
		Panel settingPanel = new Panel();
		settingPanel.setLayout(new GridLayout(3, 6));

		settingPanel.add(new Label("alpha"));
		alphaField = new DoubleField("0.01");
		settingPanel.add(alphaField);

		settingPanel.add(new Label("beta"));
		betaField = new DoubleField("0.02");
		settingPanel.add(betaField);

		settingPanel.add(new Label("gamma"));
		gammaField = new DoubleField("0.03");
		settingPanel.add(gammaField);

		settingPanel.add(new Label("layer nodes"));
		layerNodesField = new TextField("4, 8, 8, 3");
		settingPanel.add(layerNodesField);

		settingPanel.add(new Label("activators"));
		activatorField = new TextField("sss");
		settingPanel.add(activatorField);

		settingPanel.add(new Label("training rounds"));
		roundsField = new IntegerField("5000");
		settingPanel.add(roundsField);

		settingPanel.add(new Label("learning rate"));
		learningRateField = new DoubleField("0.01");
		settingPanel.add(learningRateField);

		settingPanel.add(new Label("mobp"));
		mobpField = new DoubleField("0.5");
		settingPanel.add(mobpField);

		Panel topPanel = new Panel();
		topPanel.setLayout(new BorderLayout());
		topPanel.add(BorderLayout.NORTH, sourceFilePanel);
		topPanel.add(BorderLayout.CENTER, settingPanel);

		messageTextArea = new TextArea(80, 40);

		// The bottom part: ok and exit
		Button okButton = new Button(" OK ");
		okButton.addActionListener(this);
		// DialogCloser dialogCloser = new DialogCloser(this);
		Button exitButton = new Button(" Exit ");
		// cancelButton.addActionListener(dialogCloser);
		exitButton.addActionListener(ApplicationShutdown.applicationShutdown);
		Button helpButton = new Button(" Help ");
		helpButton.setSize(20, 10);
		helpButton.addActionListener(new HelpDialog("ANN", "src/machinelearning/gui/help.txt"));
		Panel okPanel = new Panel();
		okPanel.add(okButton);
		okPanel.add(exitButton);
		okPanel.add(helpButton);

		mainFrame.setLayout(new BorderLayout());
		mainFrame.add(BorderLayout.NORTH, topPanel);
		mainFrame.add(BorderLayout.CENTER, messageTextArea);
		mainFrame.add(BorderLayout.SOUTH, okPanel);

		mainFrame.setSize(600, 500);
		mainFrame.setLocation(100, 100);
		mainFrame.addWindowListener(ApplicationShutdown.applicationShutdown);
		mainFrame.setBackground(GUICommon.MY_COLOR);
		mainFrame.setVisible(true);
	}// Of the constructor

	/**
	 *************************** 
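	 * Build the network from the settings, train it for the given number of
	 * rounds, and report the training time and the accuracy. Note that the
	 * alpha, beta, and gamma fields are not read here.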
	 *************************** 
	 */
	public void actionPerformed(ActionEvent ae) {
		String tempFilename = arffFilenameField.getText();

		// Read the layers nodes.
		String tempString = layerNodesField.getText().trim();

		int[] tempLayerNodes = null;
		try {
			tempLayerNodes = stringToIntArray(tempString);
		} catch (Exception ee) {
			ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
			return;
		} // Of try

		double tempLearningRate = learningRateField.getValue();
		double tempMobp = mobpField.getValue();
		String tempActivators = activatorField.getText().trim();
		FullAnn tempNetwork = new FullAnn(tempFilename, tempLayerNodes, tempLearningRate, tempMobp,
				tempActivators);
		int tempRounds = roundsField.getValue();

		long tempStartTime = new Date().getTime();
		for (int i = 0; i < tempRounds; i++) {
			tempNetwork.train();
		} // Of for i
		long tempEndTime = new Date().getTime();
		messageTextArea.append("\r\nSummary:\r\n");
		messageTextArea.append("Training time: " + (tempEndTime - tempStartTime) + "ms.\r\n");

		double tempAccuracy = tempNetwork.test();
		messageTextArea.append("Accuracy: " + tempAccuracy + "\r\n");
		messageTextArea.append("End.");
	}// Of actionPerformed

	/**
	 ********************************** 
	 * Convert a string with commas into an int array.
	 * 
	 * @param paraString
	 *            The source string
	 * @return An int array.
	 * @throws Exception
	 *             Exception for illegal data.
	 ********************************** 
	 */
	public static int[] stringToIntArray(String paraString) throws Exception {
		int tempCounter = 1;
		for (int i = 0; i < paraString.length(); i++) {
			if (paraString.charAt(i) == ',') {
				tempCounter++;
			} // Of if
		} // Of for i

		int[] resultArray = new int[tempCounter];

		String tempRemainingString = new String(paraString) + ",";
		String tempString;
		for (int i = 0; i < tempCounter; i++) {
			tempString = tempRemainingString.substring(0, tempRemainingString.indexOf(",")).trim();
			if (tempString.equals("")) {
				throw new Exception("Blank is unsupported");
			} // Of if

			resultArray[i] = Integer.parseInt(tempString);

			tempRemainingString = tempRemainingString
					.substring(tempRemainingString.indexOf(",") + 1);
		} // Of for i

		return resultArray;
	}// Of stringToIntArray

	/**
	 *************************** 
	 * The entrance method.
	 * 
	 * @param args
	 *            The parameters.
	 *************************** 
	 */
	public static void main(String args[]) {
		new AnnMain();
	}// Of main
}// Of class AnnMain
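stringToIntArray in AnnMain counts the commas first and then cuts the string piece by piece. A shorter sketch with String.split should behave the same way for well-formed input and for blank pieces; the class name LayerNodesParser is hypothetical. Note the limit -1: with the default split, trailing empty strings are dropped, so "4, 8," would silently become [4, 8] instead of raising the "Blank is unsupported" error.

```java
import java.util.Arrays;

/**
 * Hypothetical split-based alternative to AnnMain.stringToIntArray.
 */
class LayerNodesParser {
	/**
	 * Convert a string with commas into an int array.
	 * 
	 * @throws Exception
	 *             For a blank piece; Integer.parseInt throws a
	 *             NumberFormatException for a non-integer piece.
	 */
	public static int[] stringToIntArray(String paraString) throws Exception {
		// Limit -1 keeps trailing empty pieces, so "4, 8," is rejected.
		String[] tempPieces = paraString.split(",", -1);
		int[] resultArray = new int[tempPieces.length];
		for (int i = 0; i < tempPieces.length; i++) {
			String tempString = tempPieces[i].trim();
			if (tempString.isEmpty()) {
				throw new Exception("Blank is unsupported");
			} // Of if
			resultArray[i] = Integer.parseInt(tempString);
		} // Of for i
		return resultArray;
	}// Of stringToIntArray

	public static void main(String[] args) throws Exception {
		System.out.println(Arrays.toString(stringToIntArray("4, 8, 8, 3")));
	}// Of main
}// Of class LayerNodesParser
```

Both versions throw only Exception subclasses, so the existing try/catch in actionPerformed would report every malformed input through ErrorDialog either way.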

Day 80: GUI (4. Interfaces and the listener mechanism)

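Today we write a small program to get a feel for interfaces and the listener mechanism.

  1. Day 1 writes a Flying interface, which does not depend on any later code.
  2. Day 2 writes a Controller class, which uses only the Flying interface.
  3. Day 3 writes a Bird class, which implements the Flying interface.
  4. Day 4 writes a Plane class, which also implements the Flying interface.
  5. Day 5 writes a test class.
  6. Something magical happens: the code of Day 3 and Day 4 is called by the code of Day 2, even though Day 2 was written earlier! The reason is that the code of Day 2, Day 3, and Day 4 all follows the interface of Day 1. The caller (e.g., Controller) and the callee (e.g., Bird) can each do their own job, and as long as they follow the interface, they can work together. This is why interfaces exist.
  7. Bird and Plane implement the interface in different ways and can also use their own member variables; this gives rise to polymorphism. Polymorphism is in fact a feature brought by inheritance.
  8. Interfaces essentially support multiple inheritance: a class may implement several interfaces.
  9. From the perspective of the listener mechanism and interfaces, analyze which code is triggered by each operation on the GUI.
  10. Summarize the basic artificial neural network.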
package machinelearning.gui;

/** 
 * Explain the interface and listener mechanism.
 * 
 * @author Fan Min [email protected].
 */
// Day 1. Define an interface. 
interface Flying{
	public void fly();
}//Of interface Flying

//Day 2. Define a controller to cope with it.
class Controller{
	Flying flying;
	
	Controller(){
		flying = null;
	}//Of the constructor
	
	void setListener(Flying paraFlying){
		flying = paraFlying;
	}//Of setListener
	
	void doIt(){
		if (flying == null) {
			System.out.println("No listener is set.");
			return;
		}//Of if
		flying.fly();
	}//Of doIt
}//Of class Controller

//Day 3. Define class Bird for the interface.
class Bird implements Flying{
	double weight = 0.5;
	
	public void fly(){
		System.out.println("Bird flies, my weight is " + weight + " kg.");
	}//Of fly
}//Of class Bird

//Day 4. Define class Plane for the interface.
class Plane implements Flying{
	double price = 100000000;
	public void fly(){
		System.out.println("Plane flies, my price is " + price + " RMB.");
	}//Of fly
}//Of class Plane

//Day 5. Test the interface.
public class InterfaceTest {
	public static void main(String[] args){
		Controller tempController = new Controller();
		Flying tempFlying1 = new Bird();
		tempController.setListener(tempFlying1);
		tempController.doIt();
		
		Flying tempFlying2 = new Plane();
		tempController.setListener(tempFlying2);
		tempController.doIt();
	}//Of main
}//Of class InterfaceTest
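Two points from the list above fit in a few lines: a class may implement several interfaces (as FilenameField implements both ActionListener and FocusListener), and a controller may keep several listeners instead of one, in the spirit of AWT's addActionListener. All names here are hypothetical and chosen so they do not clash with Flying, Controller, Bird, and Plane. This sketch calls its listeners in registration order; that is a choice of this code, not a claim about how AWT orders its listeners.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Keeps several listeners instead of one (hypothetical).
 */
class MultiController {
	private final List<Flyer> listeners = new ArrayList<>();

	void addListener(Flyer paraFlyer) {
		listeners.add(paraFlyer);
	}// Of addListener

	/**
	 * Call every listener in registration order.
	 * 
	 * @return the number of listeners called.
	 */
	int doIt() {
		for (Flyer tempFlyer : listeners) {
			tempFlyer.fly();
		} // Of for tempFlyer
		return listeners.size();
	}// Of doIt

	public static void main(String[] args) {
		MultiController tempController = new MultiController();
		Drone tempDrone = new Drone();
		tempController.addListener(tempDrone);
		// A lambda is another way to implement a one-method interface.
		tempController.addListener(() -> System.out.println("Kite flies."));
		tempController.doIt();
		tempDrone.land();
	}// Of main
}// Of class MultiController

/**
 * Plays the role of Flying.
 */
interface Flyer {
	void fly();
}// Of interface Flyer

/**
 * A second, unrelated interface.
 */
interface Lander {
	void land();
}// Of interface Lander

/**
 * One class implementing two interfaces.
 */
class Drone implements Flyer, Lander {
	public void fly() {
		System.out.println("Drone flies.");
	}// Of fly

	public void land() {
		System.out.println("Drone lands.");
	}// Of land
}// Of class Drone
```

Running MultiController prints "Drone flies.", "Kite flies.", and then "Drone lands.": the controller only knows about Flyer, yet it drives both a named class and a lambda.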
