While studying neural networks recently I worked with BPNN, and with an eye toward scaling the model out in a distributed setting I turned to Mahout's MultiLayer Perceptron (mlp). So I downloaded and read through the module's source code, and I'd like to write up my reading notes here, partly so I don't forget them later, and partly to learn together with everyone.
The Mahout version I use here is 0.10, mainly because Apache seems to have removed the mlp module in Mahout 0.11 (at least I couldn't find it...).
Module path: mr.src.main.java.org.apache.mahout.classifier.mlp
This path contains five .java files:
- MultilayerPerceptron.java
- NeuralNetwork.java
- NeuralNetworkFunctions.java
- TrainMultilayerPerceptron.java
- RunMultilayerPerceptron.java
Each of these files defines a class named after the file. MultilayerPerceptron is a subclass of NeuralNetwork, which is the core of the whole NN module; NeuralNetworkFunctions implements the mathematical formulas the NN module uses; the last two files wrap the module's training process (TrainMultilayerPerceptron) and its prediction process (RunMultilayerPerceptron). Here we focus on the NeuralNetwork class and its implementation.
The NeuralNetwork class contains a number of fields and member methods; the main ones are listed below:
Main fields of the Mahout neural network module, with their getters and setters:

| Field | Getter | Setter |
| --- | --- | --- |
| LearningRate | getLearningRate() | setLearningRate() |
| MomentumWeight | getMomentumWeight() | setMomentumWeight() |
| RegularizationWeight | getRegularizationWeight() | setRegularizationWeight() |
| TrainingMethod | getTrainingMethod() | setTrainingMethod() |
| CostFunction | getCostFunction() | setCostFunction() |
Main member methods of the Mahout neural network module:

| Member method | Description |
| --- | --- |
| addLayer(int size, boolean isFinalLayer, String squashingFunctionName) | Adds a new layer to the network. size is the number of neurons in the layer; isFinalLayer says whether this layer is the network's final (output) layer; squashingFunctionName names the layer's activation function (a.k.a. squashing function). |
| trainOnline(Vector trainingInstance) | Trains the model online; the argument is the vector formed by the input features followed by the expected output features. |
| getOutput(Vector instance) | Computes the model output; the argument is the input feature vector. |
| setModelPath(String modelPath) | Sets the model path to modelPath. |
| writeModelToFile() | Writes the model to the configured modelPath. |
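To make this API concrete, here is a minimal usage sketch. It is only a sketch: the layer sizes, the learning rate, the "Sigmoid" function name, and the no-argument MultilayerPerceptron constructor are my assumptions about the surrounding 0.10 API, not something taken from the source quoted in this post.

import org.apache.mahout.classifier.mlp.MultilayerPerceptron;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class MlpSketch {
  public static void main(String[] args) {
    MultilayerPerceptron mlp = new MultilayerPerceptron();
    mlp.addLayer(2, false, "Sigmoid"); // input layer: 2 features (a bias neuron is added internally)
    mlp.addLayer(3, false, "Sigmoid"); // one hidden layer with 3 neurons
    mlp.addLayer(1, true, "Sigmoid");  // output layer: 1 neuron
    mlp.setLearningRate(0.1);
    mlp.setMomentumWeight(0.9);

    // trainingInstance = input features followed by the expected output
    Vector trainingInstance = new DenseVector(new double[] {0.2, 0.7, 1.0});
    mlp.trainOnline(trainingInstance);

    // prediction takes only the input features
    Vector output = mlp.getOutput(new DenseVector(new double[] {0.2, 0.7}));
    System.out.println(output);
  }
}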
The trainOnline() method implements the model training procedure; here is its body:

public void trainOnline(Vector trainingInstance) {
  Matrix[] matrices = trainByInstance(trainingInstance);
  updateWeightMatrices(matrices);
}

That is, it first runs trainByInstance(), stores the result in matrices, and then runs updateWeightMatrices(matrices). Next, trainByInstance():
public Matrix[] trainByInstance(Vector trainingInstance) {
  // validate training instance
  int inputDimension = layerSizeList.get(0) - 1;
  int outputDimension = layerSizeList.get(this.layerSizeList.size() - 1);
  Preconditions.checkArgument(inputDimension + outputDimension == trainingInstance.size(),
      String.format("The dimension of training instance is %d, but requires %d.", trainingInstance.size(),
          inputDimension + outputDimension));
  if (trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
    return trainByInstanceGradientDescent(trainingInstance);
  }
  throw new IllegalArgumentException("Training method is not supported.");
}
In this method, the dimension of the trainingInstance argument equals the input feature dimension plus the output feature dimension. The method first uses the class field layerSizeList to look up the node count of the first (input) layer and of the last (output) layer; the "- 1" is there because layerSizeList stores the input layer's size including its bias neuron. For example, a 2-3-1 network stores layerSizeList = [3, 4, 1], so inputDimension = 2, outputDimension = 1, and trainingInstance must have 3 elements. checkArgument() then verifies that input features plus output features add up to the size of the passed-in instance. The validated sample is then trained with the GRADIENT_DESCENT method (currently the only training method this module supports), returning trainByInstanceGradientDescent():
private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
  int inputDimension = layerSizeList.get(0) - 1;
  Vector inputInstance = new DenseVector(layerSizeList.get(0));
  inputInstance.set(0, 1); // add bias
  for (int i = 0; i < inputDimension; ++i) {
    inputInstance.set(i + 1, trainingInstance.get(i));
  }
  Vector labels =
      trainingInstance.viewPart(inputInstance.size() - 1, trainingInstance.size() - inputInstance.size() + 1);
  // initialize weight update matrices
  Matrix[] weightUpdateMatrices = new Matrix[weightMatrixList.size()];
  for (int m = 0; m < weightUpdateMatrices.length; ++m) {
    weightUpdateMatrices[m] =
        new DenseMatrix(weightMatrixList.get(m).rowSize(), weightMatrixList.get(m).columnSize());
  }
  List<Vector> internalResults = getOutputInternal(inputInstance);
  Vector deltaVec = new DenseVector(layerSizeList.get(layerSizeList.size() - 1));
  Vector output = internalResults.get(internalResults.size() - 1);
  final DoubleFunction derivativeSquashingFunction =
      NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(squashingFunctionList.size() - 1));
  final DoubleDoubleFunction costFunction =
      NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(costFunctionName);
  Matrix lastWeightMatrix = weightMatrixList.get(weightMatrixList.size() - 1);
  for (int i = 0; i < deltaVec.size(); ++i) {
    // output.get(i + 1): skip the bias term stored at index 0 of the layer output
    double costFuncDerivative = costFunction.apply(labels.get(i), output.get(i + 1));
    // Add regularization
    costFuncDerivative += regularizationWeight * lastWeightMatrix.viewRow(i).zSum();
    deltaVec.set(i, costFuncDerivative);
    deltaVec.set(i, deltaVec.get(i) * derivativeSquashingFunction.apply(output.get(i + 1)));
  }
  // Start from previous layer of output layer
  for (int layer = layerSizeList.size() - 2; layer >= 0; --layer) {
    deltaVec = backPropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]);
  }
  prevWeightUpdatesList = Arrays.asList(weightUpdateMatrices);
  return weightUpdateMatrices;
}
This method is fairly long, so let's go through it block by block:
First it parses the input argument, splitting the input features (with a bias term prepended) into inputInstance and the expected output features into labels, then allocates weightUpdateMatrices, one matrix per weight matrix of the network. Next, getOutputInternal() computes every layer's output; the last layer's output becomes output, and deltaVec is allocated with the output layer's dimension. It then fetches the output layer's derivativeSquashingFunction (fixed when the NN instance was created), the derivative of the costFunction (also fixed at creation time; squared error by default), and lastWeightMatrix, the weight matrix feeding the output layer.
After that, it computes costFuncDerivative element by element from labels and output, then folds in the regularizationWeight term and the derivative of the squashing function, yielding the final deltaVec.
Once this is done, the resulting deltaVec is propagated back through the earlier layers (the backPropagate() method), updating deltaVec layer by layer, and the updated weightUpdateMatrices are returned at the end.
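To summarize the loop above in equation form (my own notation, not from the source; $\lambda$ is regularizationWeight, and $f'$ is the derivative of the output layer's squashing function expressed as a function of the layer output, e.g. $f'(o) = o(1-o)$ for the sigmoid), the delta for output neuron $i$ is:

$$\delta_i = \left( \frac{\partial C}{\partial o_i} + \lambda \sum_j W_{ij} \right) f'(o_i)$$

Here $\partial C / \partial o_i$ corresponds to costFunction.apply(labels.get(i), output.get(i + 1)), and $\sum_j W_{ij}$ to lastWeightMatrix.viewRow(i).zSum(). The backPropagate() method then carries these deltas back through the hidden layers: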
private Vector backPropagate(int currentLayerIndex, Vector nextLayerDelta,
    List<Vector> outputCache, Matrix weightUpdateMatrix) {
  // Get layer related information
  final DoubleFunction derivativeSquashingFunction =
      NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(currentLayerIndex));
  Vector curLayerOutput = outputCache.get(currentLayerIndex);
  Matrix weightMatrix = weightMatrixList.get(currentLayerIndex);
  Matrix prevWeightMatrix = prevWeightUpdatesList.get(currentLayerIndex);
  // Next layer is not output layer, remove the delta of bias neuron
  if (currentLayerIndex != layerSizeList.size() - 2) {
    nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
  }
  Vector delta = weightMatrix.transpose().times(nextLayerDelta);
  delta = delta.assign(curLayerOutput, new DoubleDoubleFunction() {
    @Override
    public double apply(double deltaElem, double curLayerOutputElem) {
      return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem);
    }
  });
  // Update weights
  for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) {
    for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) {
      weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) *
          curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j));
    }
  }
  return delta;
}
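The nested loop at the end is an ordinary gradient-descent step with momentum; in the same notation as above (again my own summary, not from the source), with learning rate $\eta$ (learningRate) and momentum weight $\alpha$ (momentumWeight):

$$\Delta W_{ij} = -\eta \, \delta_i^{(l+1)} o_j^{(l)} + \alpha \, \Delta W_{ij}^{\mathrm{prev}}$$

where $\delta^{(l+1)}$ is nextLayerDelta, $o^{(l)}$ is curLayerOutput, and $\Delta W^{\mathrm{prev}}$ is the previous update stored in prevWeightUpdatesList.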
The above is the core algorithm of the mlp module. For pieces not walked through here, such as how costFuncDerivative is computed, how the squashing function is applied, and how backPropagate works in general, any NN textbook will do; the implementation matches the textbook algorithm exactly, so I won't repeat it. What I do want to discuss is model serialization and deserialization, because this step is a prerequisite for scaling a model out in a distributed setting:
In the mlp module, serialization and deserialization are implemented by the write() and readFields() methods of the Hadoop Writable contract. Source:
public void write(DataOutput output) throws IOException {
  // Write model type
  WritableUtils.writeString(output, modelType);
  // Write learning rate
  output.writeDouble(learningRate);
  // Write model path
  if (modelPath != null) {
    WritableUtils.writeString(output, modelPath);
  } else {
    WritableUtils.writeString(output, "null");
  }
  // Write regularization weight
  output.writeDouble(regularizationWeight);
  // Write momentum weight
  output.writeDouble(momentumWeight);
  // Write cost function
  WritableUtils.writeString(output, costFunctionName);
  // Write layer size list
  output.writeInt(layerSizeList.size());
  for (Integer aLayerSizeList : layerSizeList) {
    output.writeInt(aLayerSizeList);
  }
  WritableUtils.writeEnum(output, trainingMethod);
  // Write squashing functions
  output.writeInt(squashingFunctionList.size());
  for (String aSquashingFunctionList : squashingFunctionList) {
    WritableUtils.writeString(output, aSquashingFunctionList);
  }
  // Write weight matrices
  output.writeInt(this.weightMatrixList.size());
  for (Matrix aWeightMatrixList : weightMatrixList) {
    MatrixWritable.writeMatrix(output, aWeightMatrixList);
  }
}
/**
 * Read the fields of the model from input.
 *
 * @param input The input instance.
 * @throws IOException
 */
public void readFields(DataInput input) throws IOException {
  // Read model type
  modelType = WritableUtils.readString(input);
  if (!modelType.equals(this.getClass().getSimpleName())) {
    throw new IllegalArgumentException("The specified location does not contains the valid NeuralNetwork model.");
  }
  // Read learning rate
  learningRate = input.readDouble();
  // Read model path
  modelPath = WritableUtils.readString(input);
  if (modelPath.equals("null")) {
    modelPath = null;
  }
  // Read regularization weight
  regularizationWeight = input.readDouble();
  // Read momentum weight
  momentumWeight = input.readDouble();
  // Read cost function
  costFunctionName = WritableUtils.readString(input);
  // Read layer size list
  int numLayers = input.readInt();
  layerSizeList = new ArrayList<>();
  for (int i = 0; i < numLayers; i++) {
    layerSizeList.add(input.readInt());
  }
  trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);
  // Read squash functions
  int squashingFunctionSize = input.readInt();
  squashingFunctionList = new ArrayList<>();
  for (int i = 0; i < squashingFunctionSize; i++) {
    squashingFunctionList.add(WritableUtils.readString(input));
  }
  // Read weights and construct matrices of previous updates
  int numOfMatrices = input.readInt();
  weightMatrixList = new ArrayList<>();
  prevWeightUpdatesList = new ArrayList<>();
  for (int i = 0; i < numOfMatrices; i++) {
    Matrix matrix = MatrixWritable.readMatrix(input);
    weightMatrixList.add(matrix);
    prevWeightUpdatesList.add(new DenseMatrix(matrix.rowSize(), matrix.columnSize()));
  }
}
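Since write() and readFields() only need a DataOutput/DataInput pair, a model can be round-tripped through any stream. Below is a minimal local-file sketch; the file name is illustrative, and the no-argument constructor and "Sigmoid" name are the same assumptions as in the earlier sketch (in the module itself, writeModelToFile() writes to the configured modelPath instead):

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

import org.apache.mahout.classifier.mlp.MultilayerPerceptron;

public class SerializationSketch {
  public static void main(String[] args) throws IOException {
    MultilayerPerceptron model = new MultilayerPerceptron();
    model.addLayer(2, false, "Sigmoid");
    model.addLayer(3, false, "Sigmoid");
    model.addLayer(1, true, "Sigmoid");

    // write() streams the model type, hyperparameters, layer sizes and weight matrices
    try (DataOutputStream out = new DataOutputStream(new FileOutputStream("mlp.model"))) {
      model.write(out);
    }

    // readFields() checks the model type, then rebuilds layerSizeList,
    // squashingFunctionList, weightMatrixList and prevWeightUpdatesList
    MultilayerPerceptron restored = new MultilayerPerceptron();
    try (DataInputStream in = new DataInputStream(new FileInputStream("mlp.model"))) {
      restored.readFields(in);
    }
  }
}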