Joone 示例2

package com.pintn.joone;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Date;
import java.util.Vector;

import org.joone.engine.FullSynapse;
import org.joone.engine.LinearLayer;
import org.joone.engine.Monitor;
import org.joone.engine.NeuralNetEvent;
import org.joone.engine.NeuralNetListener;
import org.joone.engine.Pattern;
import org.joone.engine.SigmoidLayer;
import org.joone.engine.learning.TeachingSynapse;
import org.joone.io.MemoryInputSynapse;
import org.joone.io.MemoryOutputSynapse;
import org.joone.net.NeuralNet;

public class XOR_using_NeuralNet_RPROP implements NeuralNetListener,Serializable {
    private NeuralNet nnet = null; /* 初始化数据网络为空*/
    private MemoryInputSynapse  inputSynapse, desiredOutputSynapse; /*声明MemoryInputSynapse 的两个对象,输入神经元和 期望输出神经元为数组变量 */
    private MemoryOutputSynapse outputSynapse;
   
    // XOR input    双精度型二维数组 inputArray 并初始化  作为神经元的两个输入变量的数组   
    private double[][] inputArray = new double[][] {
        {0.0, 0.0},
        {0.0, 1.0},
        {1.03, 0.0},
        {0.98, 1.0}
    };
   
    // XOR desired output  双精度型二维数组 inputArray 并初始化  作为神经元的期望输出变量的数组  
    private double[][] desiredOutputArray = new double[][] {
        {0.0},
        {1.0},
        {0.103},
        {0.0}
    };
   
   
    public static void main(String args[]) {
XOR_using_NeuralNet_RPROP xor = new XOR_using_NeuralNet_RPROP();
       
       
        xor.initNeuralNet(); 
        xor.train();     
    }
   
     
    protected void initNeuralNet() {       //神经网络的初始化函数
       
        // First create the three layers
     LinearLayer  input = new LinearLayer();     
     SigmoidLayer hidden = new SigmoidLayer();   
     SigmoidLayer output = new SigmoidLayer();
       
        // set the dimensions of the layers
        input.setRows(2);   //   输入层神经元的个数为2   因为 异或(0,1)(1,1)(1,1)(0,0)
        hidden.setRows(3);  //隐含层的神经元个数为 3 这个要根据实际的经验选取
        output.setRows(1);  //输出层的神经元个数为 1   输出异或运算结果
       
        // Now create the two Synapses
        FullSynapse synapse_IH = new FullSynapse();
        FullSynapse synapse_HO = new FullSynapse();
       
        // Connect the input layer whit the hidden layer
        input.addOutputSynapse(synapse_IH); 
        hidden.addInputSynapse(synapse_IH);    
       
        // Connect the hidden layer whit the output layer
        hidden.addOutputSynapse(synapse_HO);
        output.addInputSynapse(synapse_HO);
       
        // the input to the neural net
        inputSynapse = new MemoryInputSynapse();
       
        input.addInputSynapse(inputSynapse);  //将输入神经元的输入作为神经网络 突触的输入变量
       
        // the output of the neural net
        outputSynapse = new MemoryOutputSynapse();
       
        output.addOutputSynapse(outputSynapse);//将输出神经元的输出作为神经网络突触的输入变量
       
        // The Trainer and its desired output
        desiredOutputSynapse = new MemoryInputSynapse();
       
        TeachingSynapse trainer = new TeachingSynapse();
       
        trainer.setDesired(desiredOutputSynapse);  //建立教师层,监视神经网络
       
        // Now we add this structure to a NeuralNet object//添加上面建立的神经网络的结构到神经网络
        nnet = new NeuralNet();
       
        nnet.addLayer(input, NeuralNet.INPUT_LAYER);
        nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
        nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);
        nnet.setTeacher(trainer);
        output.addOutputSynapse(trainer);
    }
   
   
    public void train() {
       
        // set the inputs
        inputSynapse.setInputArray(inputArray);   //神经网络的输入数组为inputArray数组
        inputSynapse.setAdvancedColumnSelector("1,2");//选择inputArray数组  前两个变量为输入变量
       
        // set the desired outputs
      desiredOutputSynapse.setInputArray(desiredOutputArray);
        desiredOutputSynapse.setAdvancedColumnSelector("1");//选择outputArray数组前一个变量为输出变量
       
        // get the monitor object to train or feed forward   建立监控连接 和  误差的前向反馈
        Monitor monitor = nnet.getMonitor();  
       
        // set the monitor parameters
        monitor.setLearningRate(0.8); //设定神经网络的学习率,
        monitor.setMomentum(0.3);//设定神经网络的动量 为 0.3 这两个变量与步长有关
        monitor.setTrainingPatterns(inputArray.length);//训练的模式为输入数组的长度
        monitor.setTotCicles(5000);     //设定总训练次数为5000
        // RPROP parameters
//        monitor.getl
//        monitor.getLearner().add(0, "org.joone.engine.RpropLearner");
//        monitor.getLearners().add(0, "org.joone.engine.RpropLearner");
        monitor.setBatchSize(7);//修改
        monitor.setLearningMode(0);
       
        monitor.setLearning(true);//待学习的训练监控 
        nnet.addNeuralNetListener(this);//监听神经网络
        nnet.start();   //数据网路开始训练
        nnet.getMonitor().Go();//神经网络开始监控,并运行
    }
   
   
    public void cicleTerminated(NeuralNetEvent e) {
System.out.println("cicleTerminated...");//训练终结
    }
   
    public void errorChanged(NeuralNetEvent e) {
  // System.out.println("error...");
         Monitor mon = (Monitor)e.getSource();//得到监控层的信息
         System.out.println("Cycle: "+(mon.getTotCicles()-mon.getCurrentCicle())+" RMSE:"+mon.getGlobalError());    //输出  训练的次数和  rmse 均方误差
   }
   
    public void netStarted(NeuralNetEvent e) {
System.out.println("Training...");
    }
   
    public void netStopped(NeuralNetEvent e) {
  System.out.println("Training Stopped...");
 

  System.out.println(5000);
        Monitor mon = (Monitor)e.getSource();
        // Read the last pattern and print it out
        Vector patts = outputSynapse.getAllPatterns();
        Pattern pattern = (Pattern)patts.elementAt(patts.size() - 1);
        System.out.println("Output Pattern = " + pattern.getArray()[0] + " Error: " + mon.getGlobalError());   //输出最后的输出的结果和RMSE值  
        Date Mydate =new Date();  //初始化一个时间对象
        long mystr = System.currentTimeMillis(); //初始化当前的系统时间
        System.out.println(mystr); 
       
        saveNeuralNet("d:"+ mystr +"myxor.snet"); //保存生成当前时间的myxor.snet神经网络
    }
public void saveNeuralNet(String fileName) {
try {
FileOutputStream stream = new FileOutputStream(fileName);
ObjectOutputStream out = new ObjectOutputStream(stream);
out.writeObject(nnet);//写入nnet对象
out.close();
}
catch (Exception excp) {
excp.printStackTrace();
}
}




    private NeuralNet restoreNeuralNet(String fileName) {
        NeuralNet nnet = null;  
        try {
            FileInputStream stream = new FileInputStream(fileName);
            ObjectInput input = new ObjectInputStream(stream);
            nnet = (NeuralNet)input.readObject();
        }
        catch (Exception e) {
            System.out.println( "Exception was thrown. Message is : " + e.getMessage());
        }
        return nnet;
    }
 
   
   
    public void netStoppedError(NeuralNetEvent e, String error) {
    }
   
}

你可能感兴趣的:(Joone)