用java写bp神经网络(四)

接上篇。

在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider。这篇重构这个体系。

Net

首先是Net,在上篇重新定义了激活函数和误差函数后,内容大致是这样的:

// Weight matrix of each layer; the forward pass computes weights.get(i).mmul(x) for layer i.
List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();

	// Bias column vector of each layer, added to every column (sample) of that layer's net input.
	List<DoubleMatrix> bs = new ArrayList<>();

	// Activation function of each layer, indexed like weights.
	List<ActivationFunction> activations = new ArrayList<>();

	// Cost function: valueAt(output, target) gives the cost, derivativeAt(output, target) seeds back-propagation.
	CostFunction costFunc;

	// Optional accuracy measure; may be null, in which case no accuracy is computed.
	CostFunction accuracyFunc;

	// NOTE(review): presumably the number of nodes in each layer; not read by the code shown - confirm.
	int[] nodesNum;

	// NOTE(review): presumably the number of weight layers (weights.size()) - confirm.
	int layersNum;



/**
 * Packs every layer's weight matrix and bias vector into one CompactDoubleMatrix ("super matrix"),
 * so a learner can treat all network parameters as a single object.
 * NOTE(review): learners update the net through this object, so it presumably wraps
 * (rather than copies) the matrices - confirm in CompactDoubleMatrix.
 */
public CompactDoubleMatrix getCompact() {
	List<DoubleMatrix> layerWeights = this.weights;
	List<DoubleMatrix> layerBiases = this.bs;
	return new CompactDoubleMatrix(layerWeights, layerBiases);
}

 函数getCompact()把所有层的权重矩阵(weights)和偏置矩阵(bs)打包成一个超矩阵(CompactDoubleMatrix),便于对全部参数统一更新。

DataProvider

DataProvider是数据的提供者。

/**
 * Source of training data. Each column of the returned matrices is one sample.
 */
public interface DataProvider {

	/** @return the input matrix, one sample per column. */
	public DoubleMatrix getInput();

	/** @return the target (expected output) matrix, one sample per column. */
	public DoubleMatrix getTarget();
}

 如果输入是从向量字典中查得的向量,数据提供者还需提供这个向量字典,对应接口为DictDataProvider:

/**
 * DataProvider whose inputs are vectors taken from a vector dictionary.
 */
public interface DictDataProvider extends DataProvider {

	/** @return the indices of the input vectors in the dictionary. */
	DoubleMatrix getIndexs();

	/** @return the vector dictionary itself (DictMomentLearner trains it together with the weights). */
	DoubleMatrix getDict();
}

 每一列为一个样本。getIndexs()返回输入向量在字典中的索引。

我写了一个实用的类BatchDataProviderFactory,用来把样本分割成若干个minibatch。

// Batching parameter; the article describes it as the number of batches to split the data into.
int batchSize;

	// Number of samples, i.e. the column count of the original target matrix.
	int dataLen;

	// The complete data set that is split into minibatches.
	DataProvider originalProvider;

	// NOTE(review): filled by initEndPositions() (not shown); presumably the end column of each batch - confirm.
	List<Integer> endPositions;

	// Minibatch providers built by initProviders() (not shown) and returned by getProviders().
	List<DataProvider> providers;



	/**
	 * Creates a factory that splits {@code originalProvider} into minibatches.
	 *
	 * @param batchSize        batching parameter (described in the article as the number of batches); must be positive
	 * @param originalProvider the full data set to split
	 * @throws IllegalArgumentException if {@code batchSize} is not positive
	 */
	public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
		// Reject degenerate values up front instead of letting them reach the split logic.
		if (batchSize <= 0) {
			throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
		}
		this.batchSize = batchSize;
		this.originalProvider = originalProvider;
		// Each column of the target matrix is one sample.
		this.dataLen = this.originalProvider.getTarget().columns;
		this.initEndPositions();
		this.initProviders();
	}



	/**
	 * Creates a factory using the default batchSize of 4.
	 *
	 * @param originalProvider the full data set to split
	 */
	public BatchDataProviderFactory(DataProvider originalProvider) {
		this(4, originalProvider);
	}

	/**
	 * @return the minibatch providers built at construction time (the internal list, not a copy)
	 */
	public List<DataProvider> getProviders() {
		List<DataProvider> batches = this.providers;
		return batches;
	}

 batchSize指明要分成多少批(默认为4),getProviders()返回生成的各个minibatch,originalProvider为被分割的原始数据。

Propagation

Propagation负责神经网络的正向传播和反向传播过程。接口定义如下:

/**
 * Runs forward and backward propagation of a network over a set of samples.
 */
public interface Propagation {

	/**
	 * Propagates the samples supplied by {@code provider} through {@code net}.
	 *
	 * @return the result of this pass
	 */
	PropagationResult propagate(Net net, DataProvider provider);
}

 传播函数propagate用指定数据对指定网络进行传播操作,返回执行结果。

BasePropagation实现了该接口,实现了简单的反向传播:

public class BasePropagation implements Propagation{



	/**
	 * Forward pass over a batch of samples (one sample per column of {@code input}).
	 *
	 * <p>For each layer it records, in order, the activation derivative at the layer's net input
	 * ({@code derivativeResult}) and the layer's activation output ({@code finalResult}).
	 * Pre-activations are not retained, and {@code netResult} is left null.
	 *
	 * @param net   the network to evaluate
	 * @param input input matrix, one sample per column
	 * @return the per-layer outputs and derivatives needed by backward()
	 */
	protected ForwardResult forward(Net net, DoubleMatrix input) {
		ForwardResult result = new ForwardResult();
		result.input = input;
		DoubleMatrix layerOutput = input;
		for (int layer = 0; layer < net.weights.size(); layer++) {
			DoubleMatrix weight = net.weights.get(layer);
			DoubleMatrix bias = net.bs.get(layer);
			final ActivationFunction activation = net.activations.get(layer);

			// Net input of this layer: W * x + b, with b added to every column (sample).
			DoubleMatrix netInput = weight.mmul(layerOutput).addColumnVector(bias);

			// Derivative is taken at the net input, before the activation is applied.
			result.derivativeResult.add(activation.derivativeAt(netInput));

			layerOutput = activation.valueAt(netInput);
			result.finalResult.add(layerOutput);
		}
		// Pre-activations are never needed downstream, so they are not accumulated at all.
		result.netResult = null;
		return result;
	}



	



    // Backward pass over a batch of samples; weight/bias gradients are averaged over the samples.
	/**
	 * Back-propagates the error of {@code forwardResult} against {@code target}.
	 *
	 * <p>Fills in the cost, the accuracy (only when {@code net.accuracyFunc} is set), one weight
	 * gradient and one bias gradient per layer (in layer order, averaged over the samples = columns
	 * of {@code target}), and a single entry in {@code deltas}: the delta at the input layer, which
	 * is not averaged.
	 *
	 * @param net           the network that produced {@code forwardResult}
	 * @param target        expected outputs, one sample per column
	 * @param forwardResult result of forward() for the same samples
	 * @return the costs, gradients and input-layer delta
	 */
	protected BackwardResult backward(Net net, DoubleMatrix target,
			ForwardResult forwardResult) {
		BackwardResult result = new BackwardResult();

		DoubleMatrix output = forwardResult.getOutput();
		DoubleMatrix outputDerivative = forwardResult.getOutputDerivative();

		result.cost = net.costFunc.valueAt(output, target);
		// Output-layer delta: dCost/dOutput multiplied elementwise by the activation derivative.
		DoubleMatrix delta = net.costFunc.derivativeAt(output, target).muli(outputDerivative);
		if (net.accuracyFunc != null) {
			result.accuracy = net.accuracyFunc.valueAt(output, target);
		}

		int sampleCount = target.columns;
		// Iterate over the same layers forward() used, so the two passes always agree.
		for (int i = net.weights.size() - 1; i >= 0; i--) {
			DoubleMatrix layerInput = i == 0 ? forwardResult.input
					: forwardResult.finalResult.get(i - 1);

			// Weight gradient and bias gradient, both averaged over all samples.
			result.gradients.add(delta.mmul(layerInput.transpose()).div(sampleCount));
			result.biasGradients.add(delta.rowMeans());

			// Delta of the previous layer; at i == 0 this is the input-layer delta,
			// which is kept as-is (not multiplied by a derivative, not averaged).
			DoubleMatrix previousDelta = net.weights.get(i).transpose().mmul(delta);
			if (i > 0) {
				previousDelta.muli(forwardResult.derivativeResult.get(i - 1));
			}
			delta = previousDelta;
		}
		// Gradients were collected from the last layer to the first.
		Collections.reverse(result.gradients);
		Collections.reverse(result.biasGradients);

		// Only the input-layer delta is exposed; intermediate deltas are never stored.
		result.deltas.clear();
		result.deltas.add(delta);

		return result;
	}



	/**
	 * Runs forward() on the provider's input, then backward() against its target.
	 *
	 * @return a result built from the backward pass, with {@code output} set to the network output
	 */
	@Override
	public PropagationResult propagate(Net net, DataProvider provider) {
		DoubleMatrix input = provider.getInput();
		ForwardResult forwardPass = this.forward(net, input);
		DoubleMatrix target = provider.getTarget();
		PropagationResult propagationResult =
				new PropagationResult(this.backward(net, target, forwardPass));
		propagationResult.output = forwardPass.getOutput();
		return propagationResult;
	}

 PropagationResult的定义大致如下(省略了构造函数和getMeanCost()等方法):

/**
 * Result of one propagation pass over a batch of samples.
 * (Abbreviated in the article: constructors and helpers such as getMeanCost() are not shown.)
 */
public class PropagationResult {

		// Network output: outputLen x sampleLength
		DoubleMatrix output;

		// Cost per sample: 1 x sampleLength
		DoubleMatrix cost;

		// Accuracy per sample: 1 x sampleLength
		DoubleMatrix accuracy;

		// Weight gradient of each layer
		private List<DoubleMatrix> gradients;

		// Bias gradient of each layer
		private List<DoubleMatrix> biasGradients;

		// Delta at the input layer: inputLen x sampleLength
		DoubleMatrix inputDeltas;

		/**
		 * @return the weight and bias gradients packed as one CompactDoubleMatrix
		 */
		public CompactDoubleMatrix getCompact() {
			List<DoubleMatrix> weightGrads = this.gradients;
			List<DoubleMatrix> biasGrads = this.biasGradients;
			return new CompactDoubleMatrix(weightGrads, biasGrads);
		}
}

 另一个实现了该接口的类为MiniBatchPropagation。它在内部以并行方式对样本进行传播,然后把每个minibatch的结果综合起来,内部用到了BatchDataProviderFactory类和BasePropagation类。

Trainer

Trainer接口定义为:

/**
 * Trains a network on a data set.
 */
public interface Trainer {

    /** Trains {@code net} on the samples supplied by {@code provider}. */
    void train(Net net, DataProvider provider);
}

简单的实现类为:

/**
 * Trainer that runs a fixed number of epochs without early stopping. Each epoch propagates the
 * whole data set once, lets the learner adjust the weights, and records the mean cost and
 * accuracy when they are available.
 */
public class CommonTrainer implements Trainer {

	// Number of epochs to run.
	int ecophs;

	Learner learner;

	Propagation propagation;

	// Mean cost of every epoch, in order.
	List<Double> costs = new ArrayList<>();

	// Mean accuracy of every epoch, in order.
	List<Double> accuracys = new ArrayList<>();

	/**
	 * Runs one epoch: propagate, learn, then record the epoch's mean cost and accuracy.
	 */
	public void trainOne(Net net, DataProvider provider) {
		PropagationResult outcome = propagation.propagate(net, provider);
		learner.learn(net, outcome, provider);

		Double meanCost = outcome.getMeanCost();
		Double meanAccuracy = outcome.getMeanAccuracy();
		recordIfPresent(costs, meanCost);
		recordIfPresent(accuracys, meanAccuracy);
	}

	// Appends value to history unless it is null.
	private static void recordIfPresent(List<Double> history, Double value) {
		if (value == null) {
			return;
		}
		history.add(value);
	}

	@Override
	public void train(Net net, DataProvider provider) {
		int epoch = 0;
		while (epoch < ecophs) {
			System.out.println("echops:" + epoch);
			trainOne(net, provider);
			epoch++;
		}
	}
}

 简单地迭代ecophs次(即epoch数),没有提前停止(early stopping)功能,每次迭代都用Learner调整权重。

Learner

Learner根据每次传播结果对网络权重进行调整,接口定义如下:

/**
 * Adjusts network parameters from the result of a propagation pass.
 *
 * @param <N> network type
 * @param <P> data provider type
 */
public interface Learner<N extends Net, P extends DataProvider> {

    /** Adjusts the parameters of {@code net} using {@code propResult}, computed on {@code provider}. */
    void learn(N net, PropagationResult propResult, P provider);
}

 一个结合动量因子与自适应学习率进行调整的简单实现类为:

/**
 * Learner that combines momentum with an adaptive learning rate.
 *
 * <p>On every call to {@link #learn}:
 * <ol>
 * <li>{@code currentEta} and {@code currentMoment} are adapted from the change in mean cost
 * (see {@link #modifyParameter}).</li>
 * <li>The step {@code currentEta * (1 - currentMoment) * gradient + currentMoment * previousStep}
 * is subtracted from the packed parameters.</li>
 * </ol>
 *
 * @param <N> network type
 * @param <P> data provider type
 */
public class MomentAdaptLearner<N extends Net, P extends DataProvider>
		implements Learner<N, P> {

	/** Upper bound for the adapted learning rate. */
	private static final double MAX_ETA = 10;

	/** Lower bound for the adapted learning rate. */
	private static final double MIN_ETA = 0.0001;

	// Momentum factor, restored whenever the cost decreases.
	double moment = 0.7;

	// Growth factor applied to the learning rate when the cost decreases.
	double lmd = 1.05;

	// Mean cost seen on the previous call (0 before the first call).
	double preCost = 0;

	// Initial learning rate; also the value it is reset to when the cost jumps.
	double eta = 0.01;

	double currentEta = eta;

	double currentMoment = moment;

	// Previous update step (already scaled); created by init() on the first learn() call.
	CompactDoubleMatrix preGradient;

	public MomentAdaptLearner(double moment, double eta) {
		this.moment = moment;
		this.eta = eta;
		this.currentEta = eta;
		this.currentMoment = moment;
	}

	public MomentAdaptLearner() {
	}

	@Override
	public void learn(N net, PropagationResult propResult, P provider) {
		if (this.preGradient == null) {
			init(net, propResult, provider);
		}

		// CommonTrainer treats getMeanCost() as possibly null. In that case keep the current
		// parameters instead of throwing NullPointerException when unboxing.
		Double cost = propResult.getMeanCost();
		if (cost != null) {
			this.modifyParameter(cost);
		}
		System.out.println("current eta:" + this.currentEta);
		System.out.println("current moment:" + this.currentMoment);
		this.updateGradient(net, propResult, provider);
	}

	/**
	 * Applies one momentum step to the packed parameters and remembers it for the next call.
	 */
	public void updateGradient(N net, PropagationResult propResult, P provider) {
		CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult,
				provider);
		CompactDoubleMatrix gradCompact = this.getGradientCompact(net,
				propResult, provider);
		gradCompact = gradCompact.mul(currentEta * (1 - currentMoment)).addi(
				preGradient.mul(currentMoment));
		netCompact.subi(gradCompact);
		this.preGradient = gradCompact;
	}

	/** @return the parameters to update; subclasses may append extra matrices. */
	public CompactDoubleMatrix getNetCompact(N net,
			PropagationResult propResult, P provider) {
		return net.getCompact();
	}

	/** @return the gradients matching getNetCompact(); subclasses append matching extras. */
	public CompactDoubleMatrix getGradientCompact(N net,
			PropagationResult propResult, P provider) {
		return propResult.getCompact();
	}

	/**
	 * Adapts the learning rate and momentum from the new mean cost, then clamps the learning
	 * rate to [MIN_ETA, MAX_ETA].
	 *
	 * <ul>
	 * <li>If the cost decreased, the rate grows by {@code lmd} and full momentum is restored.</li>
	 * <li>If the cost rose by less than 4%, both the rate and the momentum shrink by 0.7.</li>
	 * <li>Otherwise the rate is reset to {@code eta} and the momentum set to 0.1.</li>
	 * </ul>
	 *
	 * <p>The clamp is applied after adapting, so the rate actually used for this update is
	 * always in range. Adaptation is never skipped because of the clamp.
	 *
	 * @param cost mean cost of the current pass
	 */
	public void modifyParameter(double cost) {
		if (cost < this.preCost) {
			this.currentEta *= lmd;
			this.currentMoment = moment;
		} else if (cost < 1.04 * this.preCost) {
			this.currentEta *= 0.7;
			this.currentMoment *= 0.7;
		} else {
			this.currentEta = eta;
			this.currentMoment = 0.1;
		}
		this.currentEta = Math.max(MIN_ETA, Math.min(MAX_ETA, this.currentEta));
		this.preCost = cost;
	}

	/**
	 * Creates the initial "previous step" from a PropagationResult built for {@code net}.
	 * NOTE(review): presumably this yields zero gradients shaped like the net - confirm in PropagationResult(Net).
	 */
	public void init(Net net, PropagationResult propResult, P provider) {
		PropagationResult pResult = new PropagationResult(net);
		preGradient = pResult.getCompact().dup();
	}

}

 在上面的代码中可以看到,CompactDoubleMatrix类对权重等自变量进行了封装,使代码更加简洁:在这里它表现为一个超矩阵(或超向量),完全屏蔽了内部结构。

同时,其子类DictMomentLearner实现了同步更新字典的功能,代码也很简洁:只需把需要调整的矩阵append到超矩阵中,父类会统一对其进行调整:

/**
 * MomentAdaptLearner that also trains the vector dictionary of a DictDataProvider.
 * The dictionary is appended to the packed parameters, and its gradient to the packed
 * gradients, so the parent class updates both together.
 */
public class DictMomentLearner extends
		MomentAdaptLearner<Net, DictDataProvider> {

	public DictMomentLearner(double moment, double eta) {
		super(moment, eta);
	}

	public DictMomentLearner() {
		super();
	}

	/** @return the network parameters followed by the dictionary matrix. */
	@Override
	public CompactDoubleMatrix getNetCompact(Net net,
			PropagationResult propResult, DictDataProvider provider) {
		CompactDoubleMatrix parameters = super.getNetCompact(net, propResult, provider);
		DoubleMatrix dictionary = provider.getDict();
		parameters.append(dictionary);
		return parameters;
	}

	/** @return the network gradients followed by the dictionary gradient. */
	@Override
	public CompactDoubleMatrix getGradientCompact(Net net,
			PropagationResult propResult, DictDataProvider provider) {
		CompactDoubleMatrix gradients = super.getGradientCompact(net, propResult, provider);
		DoubleMatrix dictionaryGradient = DictUtil.getDictGradient(provider, propResult);
		gradients.append(dictionaryGradient);
		return gradients;
	}

	/** Extends the initial previous step with a zero matrix shaped like the dictionary. */
	@Override
	public void init(Net net, PropagationResult propResult,
			DictDataProvider provider) {
		DoubleMatrix dictionary = provider.getDict();
		DoubleMatrix zeroDictStep = DoubleMatrix.zeros(dictionary.rows, dictionary.columns);
		super.init(net, propResult, provider);
		this.preGradient.append(zeroDictStep);
	}
}

 

你可能感兴趣的:(java)