日撸代码300行: Day 72 (BP Neural Network with a Fixed Activation Function)

The code comes from Min Fan's series “日撸 Java 三百行” (Days 71–80):

日撸 Java 三百行(71-80天,BP 神经网络)_闵帆的博客-CSDN博客

SimpleAnn extends the abstract class GeneralAnn and must therefore implement the abstract methods declared in it.
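
For orientation, here is a minimal sketch of what GeneralAnn is assumed to provide, reconstructed purely from how SimpleAnn uses it. The member names and signatures below are inferred, not copied from the Day-71 code:

// Assumed shape of the parent class, inferred from SimpleAnn's usage.
public abstract class GeneralAnn {
	int numLayers;        // Number of layers, including input and output.
	int[] layerNumNodes;  // Number of nodes in each layer.
	double learningRate;  // Step size for weight updates.
	double mobp;          // Momentum coefficient.
	java.util.Random random = new java.util.Random();

	public GeneralAnn(String paraFilename, int[] paraLayerNumNodes,
			double paraLearningRate, double paraMobp) {
		// Load the arff data and store the hyper-parameters (omitted here).
	}

	// train() and test() are also implemented in the real class (omitted here).

	// The abstract methods every subclass must implement.
	public abstract double[] forward(double[] paraInput);
	public abstract void backPropagation(double[] paraTarget);
}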

The core of the code is the implementation of two methods: forward propagation and back propagation.

Mobp: momentum coefficient. SGD usually descends quickly, but this brings another problem: the updates can be unstable and oscillate. Adding an "inertia" term means that when choosing the descent direction we consider not only the current direction but also the direction of the previous update. Weighting the two together can, in some cases, suppress oscillation and let the search escape a local basin and settle into the global one. The momentum coefficient is simply the weight given to the previous update direction. When the error surface contains flat regions, momentum also lets SGD learn faster, which makes it a commonly used acceleration technique for gradient descent.

Learning rate: the learning rate determines how fast the weights are updated. If it is set too large, the result overshoots the optimum; if too small, descent is too slow. Both hyper-parameters appear in the weight update shown below.
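
Concretely, both hyper-parameters enter the weight update performed in backPropagation later in this post (the two lines are quoted from that method):

	// Keep a fraction mobp of the previous change, add the scaled gradient term.
	edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i]
			+ learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
	edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];

The momentum term reuses the previous change, while learningRate scales the contribution of the current error.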

package machinelearning.ann;

/**
 * Back-propagation neural networks.
 * 
 * @author WX873
 *
 */
public class SimpleAnn extends GeneralAnn{
	
	/**
	 * The value of each node that changes during the forward process. The first
	 * dimension stands for the layer, and the second stands for the node.
	 */
	public double[][] layerNodeValues;
	
	/**
	 * The error on each node that changes during the back-propagation process.
	 * The first dimension stands for the layer, and the second stands for the nodes.
	 */
	public double[][] layerNodeErrors;
	
	/**
	 * The weights of edges. The first dimension stands for the layer, the
	 * second stands for the node index of the layer, and the third dimension
	 * stands for the node index of the next layer.
	 */
	public double[][][] edgeWeights;
	
	/**
	 * The change of edge weights. It has the same size as edgeWeights.
	 */
	public double[][][] edgeWeightsDelta;
		
	/**
	 * ************************************************************************
	 * The first constructor.
	 * 
	 * @param paraFilename           The arff filename.
	 * @param paraLayerNumNodes      The number of nodes for each layer (may be different).
	 * @param paraLearningRate       Learning rate.
	 * @param paraMobp               Momentum coefficient.
	 * ************************************************************************
	 */
	public SimpleAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
		super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);
		
		// Step 1. Across layer initialization.
		layerNodeValues = new double[numLayers][];
		layerNodeErrors = new double[numLayers][];
		edgeWeights = new double[numLayers - 1][][];
		edgeWeightsDelta = new double[numLayers - 1][][];
		
		// Step 2. Inner layer initialization.
		for (int l = 0; l < numLayers; l++) {
			layerNodeValues[l] = new double[layerNumNodes[l]];
			layerNodeErrors[l] = new double[layerNumNodes[l]];
			
			// One less layer because each edge crosses two layers.
			if (l + 1 == numLayers) {
				break;
			}//of if
			
			// In layerNumNodes[l] + 1, the last position is reserved for the offset (bias).
			edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]]; 
			// The first dimension is the nodes of this layer plus one bias node; the second is the number of nodes in the next layer.
			edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
			
			for (int j = 0; j < layerNumNodes[l] + 1; j++) {
				for (int i = 0; i < layerNumNodes[l + 1]; i++) {
					// Initialize weights.
					edgeWeights[l][j][i] = random.nextDouble();
				}//of for i
			}//of for j
		}//of for l
		
	}// Of the constructor

	/**
	 * ***********************************************************
	 * Forward prediction.
	 * 
	 * @param paraInput   The input data of one instance.
	 * @return   The data at the output end.
	 * ***********************************************************
	 */
	public double[] forward(double[] paraInput) {
		// Initialize the input layer.
		for (int i = 0; i < layerNodeValues[0].length; i++) {
			layerNodeValues[0][i] = paraInput[i];
		}//of for i
		
		// Calculate the node values of each layer.
		double z;
		for (int l = 1; l < numLayers; l++) {
			for (int j = 0; j < layerNodeValues[l].length; j++) {
				// Initialize according to the offset, which is always +1
				// z starts as the weight of the bias edge: the edge from the bias node of layer l-1 (index layerNodeValues[l - 1].length) to node j of layer l.
				z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
				
				// Weighted sum on all edges for this node.
				for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
					z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
				}//of for i (loop over the nodes of the previous layer)
				
				// Sigmoid activation.
				// This line should be changed for other activation functions.
				layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
			}//of for j (loop over the nodes of the current layer)
		}//of for l
		return layerNodeValues[numLayers - 1];
	}//of forward
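
	// In formula form, for every node j of layer l the forward step computes
	//   z = b_j + sum_i (w_ij * o_i),   where b_j is the bias edge weight,
	//   o_j = sigmoid(z) = 1 / (1 + exp(-z)).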

	/**
	 * ******************************************************
	 * Back propagation and change the edge weights.
	 * 
	 * @param paraTarget  For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 * ******************************************************
	 */
	public void backPropagation(double[] paraTarget) {
		// Step 1. Initialize the output layer error.
		int l = numLayers - 1;   // Layer l is the output layer.
		for (int j = 0; j < layerNodeErrors[l].length; j++) {
			layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) 
					* (paraTarget[j] - layerNodeValues[l][j]); // Apply the output-layer error formula.
		}//of for j
		
		// Step 2. Back-propagation even for l == 0
		while (l > 0) {
			l--;
			// Layer l, for each node.
			for (int j = 0; j < layerNumNodes[l]; j++) {
				double z = 0.0;
				// For each node of the next layer.
				for (int i = 0; i < layerNumNodes[l + 1]; i++) {
					if (l > 0) {
						z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
						// The error of node i in layer l+1, times the weight of the edge from node j of layer l to node i of layer l+1.
					}//of if
					
					// Weight adjusting.
					edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i] + 
							learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
					edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];
					if (j == layerNumNodes[l] - 1) {
						// Weight adjusting for the offset part.
						edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
								+ learningRate * layerNodeErrors[l + 1][i];
						edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
					}//of if
				}//of for i
				
				// Record the error according to the differential of Sigmoid.
				// This line should be changed for other activation functions.
				layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
			}//of for j
		}//of while
	}//of backPropagation
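
	// In formula form, using the sigmoid derivative s'(z) = o * (1 - o):
	//   Output layer error:  error_j = o_j * (1 - o_j) * (t_j - o_j).
	//   Hidden layer error:  error_j = o_j * (1 - o_j) * sum_i (error_i * w_ji).
	// Each weight then changes by
	//   delta_w = mobp * delta_w + learningRate * error_i * o_j,
	// which matches the update lines above.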
	
	public static void main(String args[]) {
		int[] tempLayerNodes = {4, 8, 8, 3};
		SimpleAnn tempNetwork = new SimpleAnn("E:/Datasets/UCIdatasets/其他数据集/iris.arff", 
				tempLayerNodes, 0.01, 0.6);
		
		for (int round = 0; round < 5000; round++) {
			System.out.println("The round is: " + round);
			tempNetwork.train();
		}//of for round
		
		double tempAccuracy = tempNetwork.test();
		System.out.println("The accuracy is: " + tempAccuracy);
	}//of main

}//of SimpleAnn
