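This is an example of a perceptron learning the AND rule, copied from the NEU samples.
First, define a TrainingSet. Its two parameters are the size of the input vector and the size of the output vector. See the code: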
/**
 * Creates an instance of new empty training set
 *
 * @param inputVectorSize
 * @param outputVectorSize
 */
public TrainingSet(int inputVectorSize, int outputVectorSize) {
    this.elements = new Vector<TrainingElement>();
    this.inputVectorSize = inputVectorSize;
    this.outputVectorSize = outputVectorSize;
}
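Next, define a perceptron. Its two parameters are the number of neurons in the input layer and the number of neurons in the output layer. The code: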
/**
 * Creates new Perceptron with specified number of neurons in input and
 * output layer, with Step transfer function
 *
 * @param inputNeuronsCount
 *            number of neurons in input layer
 * @param outputNeuronsCount
 *            number of neurons in output layer
 */
public Perceptron(int inputNeuronsCount, int outputNeuronsCount) {
    this.createNetwork(inputNeuronsCount, outputNeuronsCount, TransferFunctionType.STEP);
}
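To make the Step transfer function concrete, here is a minimal sketch in plain Java that does not use the library at all. The class name StepNeuronSketch, the hand-picked weights, and the threshold of 1.5 are assumptions chosen for illustration; a trained perceptron would arrive at its own values.

```java
/** Illustrative only: a single neuron with a step transfer function. */
public class StepNeuronSketch {
    // Hypothetical hand-picked parameters, not taken from the library.
    private static final double[] WEIGHTS = {1.0, 1.0};
    private static final double THRESHOLD = 1.5;

    /** Step transfer function: 1 when the net input exceeds the threshold, else 0. */
    static double step(double netInput) {
        return netInput > THRESHOLD ? 1.0 : 0.0;
    }

    /** Weighted sum of the inputs, passed through the step function. */
    static double output(double[] input) {
        double net = 0.0;
        for (int i = 0; i < input.length; i++) {
            net += WEIGHTS[i] * input[i];
        }
        return step(net);
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        for (double[] in : inputs) {
            System.out.println((int) in[0] + " AND " + (int) in[1] + " -> " + (int) output(in));
        }
    }
}
```

With these values only the input {1, 1} produces a net input above the threshold (2.0 > 1.5), so the neuron reproduces the AND truth table.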
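Then start the learning process in the current thread. For cases that need asynchronous execution, learnInNewThread is provided.
Train the network.
Verify the output.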
public void testPerceptron() {
    // AND truth table: two inputs, one output; only {1, 1} maps to 1
    TrainingSet trainingSet = new TrainingSet(2, 1);
    trainingSet.addElement(new SupervisedTrainingElement(new double[]{0, 0}, new double[]{0}));
    trainingSet.addElement(new SupervisedTrainingElement(new double[]{0, 1}, new double[]{0}));
    trainingSet.addElement(new SupervisedTrainingElement(new double[]{1, 0}, new double[]{0}));
    trainingSet.addElement(new SupervisedTrainingElement(new double[]{1, 1}, new double[]{1}));

    // Perceptron with 2 input neurons and 1 output neuron, trained in the calling thread
    NeuralNetwork myPerceptron = new Perceptron(2, 1);
    myPerceptron.learnInSameThread(trainingSet);
    System.out.println("Testing trained perceptron");
    neuralNetwork(myPerceptron, trainingSet);

    // Save the trained network, load it back, and run the same check again
    myPerceptron.save("mySamplePerceptron.nnet");
    NeuralNetwork loadedPerceptron = NeuralNetwork.load("mySamplePerceptron.nnet");
    System.out.println("Testing loaded perceptron");
    neuralNetwork(loadedPerceptron, trainingSet);
}

// Feeds every training input to the network and asserts that the output equals the logical AND
private void neuralNetwork(NeuralNetwork neuralNet, TrainingSet trainingSet) {
    for (TrainingElement trainingElement : trainingSet.trainingElements()) {
        neuralNet.setInput(trainingElement.getInput());
        neuralNet.calculate();
        Vector<Double> networkOutput = neuralNet.getOutput();
        double d1 = trainingElement.getInput().get(0);
        double d2 = trainingElement.getInput().get(1);
        double result = networkOutput.get(0);
        // Interpret 1.0 as true and anything else as false
        boolean b1 = d1 == 1.0;
        boolean b2 = d2 == 1.0;
        boolean r1 = result == 1.0;
        assertEquals(b1 && b2, r1);
    }
}
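What a training call of this kind does can be sketched without the library. Below is a minimal version of the classic perceptron learning rule in plain Java. It is an assumption-laden illustration, not the library's implementation: the class PerceptronRuleSketch, the learning rate of 0.1, the zero initial weights, the firing condition (net input above 0), and the cap of 100 epochs are all hypothetical choices. The main method also shows the general idea of running training on a separate thread and waiting for it to finish, using only java.lang.Thread.

```java
/** Illustrative only: perceptron learning rule applied to the AND truth table. */
public class PerceptronRuleSketch {
    final double[] weights = new double[2]; // assumed to start at zero
    double bias = 0.0;

    /** Weighted sum plus bias, passed through a step function that fires above 0. */
    double predict(double[] x) {
        double net = bias;
        for (int i = 0; i < weights.length; i++) {
            net += weights[i] * x[i];
        }
        return net > 0 ? 1.0 : 0.0;
    }

    /** Trains until an epoch has no errors; returns that epoch number, or -1 if the cap is hit. */
    int train(double[][] inputs, double[] targets, double rate, int maxEpochs) {
        for (int epoch = 1; epoch <= maxEpochs; epoch++) {
            int errors = 0;
            for (int n = 0; n < inputs.length; n++) {
                double error = targets[n] - predict(inputs[n]);
                if (error != 0.0) {
                    errors++;
                    // Move the weights and bias toward the desired output
                    for (int i = 0; i < weights.length; i++) {
                        weights[i] += rate * error * inputs[n][i];
                    }
                    bias += rate * error;
                }
            }
            if (errors == 0) {
                return epoch;
            }
        }
        return -1;
    }

    public static void main(String[] args) throws InterruptedException {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] targets = {0, 0, 0, 1};
        PerceptronRuleSketch p = new PerceptronRuleSketch();

        // Asynchronous variant: train on a worker thread, then wait for it
        Thread worker = new Thread(() -> p.train(inputs, targets, 0.1, 100));
        worker.start();
        worker.join(); // join() makes the trained weights visible to this thread

        for (double[] in : inputs) {
            System.out.println((int) in[0] + " AND " + (int) in[1] + " -> " + (int) p.predict(in));
        }
    }
}
```

Because AND is linearly separable, this rule settles on weights that classify all four rows correctly. Calling train directly instead of through a Thread gives the synchronous behavior.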