Java Study (Day 37)

Source: 日撸 Java 三百行 (Days 81-90, CNN convolutional neural network), 闵帆's blog on CSDN

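Contents

  • Preface
  • Convolutional neural network (code)
    • I. Reading and storing the dataset
      • 1. Dataset description
      • 2. Code
      • 3. Run screenshot
    • II. Basic operations on the kernel size
      • 1. Operations
      • 2. Code
      • 3. Run screenshot
    • III. Math utility class
      • 1. Utility functions
      • 2. Code
    • IV. Network structure and parameters
    • V. Building the neural network
      • 1. Forward propagation
      • 2. Back-propagation
      • 4. Code
      • 5. Run screenshot
  • Summary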

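Preface

The code in this post comes from the CSDN article 日撸 Java 三百行 (Days 81-90, CNN convolutional neural network).

I will use this code to build a deeper understanding of CNNs.

Convolutional neural network (code)

I. Reading and storing the dataset

1. Dataset description

Here is a brief description of the dataset we need to read.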

0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0

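At first glance this is just a collection of 0s and 1s. But if we treat these numbers as the pixels of an image and render them with a suitable tool, we get the picture below.

[Figure 1: the row above rendered as an image]

The image is 28 × 28 pixels, which accounts for 784 values. But doesn't the row have one extra number at the end? What does that number mean? Look closely at the picture: it resembles the digit '0'. To check this guess, let's look at another row.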

0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3

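[Figure 2: the second row rendered as an image]

The digit in the image is not written very neatly, but it can still be recognized as '3', and the extra number at the end of the row is exactly 3. From this we conclude that each row of the dataset represents one image, whose black and white pixels are encoded as 0 and 1, and that the last number in the row is the digit shown in the image.

Reading the dataset therefore means storing each image's pixels in an array whose size equals the image size, and storing the digit shown in the image as a separate value, which serves as the image's label.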
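The row layout described above can be sketched in a few lines of Java. The class name RowSketch, its method names, and the tiny 3 × 3 sample row are illustrative assumptions, not part of the original code; a real row would hold 28 × 28 = 784 pixels followed by the label.

```java
import java.util.Arrays;

class RowSketch {
    /** Parse one comma-separated row into doubles. */
    static double[] parseRow(String row) {
        String[] parts = row.split(",");
        double[] values = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            values[i] = Double.parseDouble(parts[i].trim());
        }
        return values;
    }

    /** The pixels are all values except the last one. */
    static double[] pixelsOf(double[] values) {
        return Arrays.copyOfRange(values, 0, values.length - 1);
    }

    /** The label is the last value of the row. */
    static int labelOf(double[] values) {
        return (int) values[values.length - 1];
    }

    /** Render the pixels as text, one image row per line ('#' for 1, '.' for 0). */
    static String render(double[] pixels, int width) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < pixels.length; i++) {
            sb.append(pixels[i] > 0.5 ? '#' : '.');
            if ((i + 1) % width == 0) {
                sb.append('\n');
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Hypothetical 3 x 3 image followed by its label 7.
        double[] values = parseRow("1,1,1,0,0,1,0,0,1,7");
        System.out.print(render(pixelsOf(values), 3));
        System.out.println("label = " + labelOf(values));
    }
}
```

Calling render with a width of 28 on a real row would print the kind of picture shown in the figures.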

2. Code

package cnn;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Manage the dataset.
 *
 * @author Shi-Huai Wen Email: [email protected].
 */
public class Dataset {
    /**
     * All instances organized by a list.
     */
    private List<Instance> instances;

    /**
     * The label index.
     */
    private int labelIndex;

    /**
     * The max label (labels start from 0).
     */
    private double maxLabel = -1;

    /**
     * **********************
     * The first constructor.
     * **********************
     */
    public Dataset() {
        labelIndex = -1;
        instances = new ArrayList<>();
    }// Of the first constructor

    /**
     * **********************
     * The second constructor.
     *
     * @param paraFilename   The filename.
     * @param paraSplitSign  Often comma.
     * @param paraLabelIndex Often the last column.
     * **********************
     */
    public Dataset(String paraFilename, String paraSplitSign, int paraLabelIndex) {
        instances = new ArrayList<>();
        labelIndex = paraLabelIndex;

        File tempFile = new File(paraFilename);
        // try-with-resources closes the reader even if reading fails.
        try (BufferedReader tempReader = new BufferedReader(new FileReader(tempFile))) {
            String tempLine;
            while ((tempLine = tempReader.readLine()) != null) {
                // Skip blank lines: splitting an empty string yields [""],
                // which cannot be parsed as a number.
                if (tempLine.trim().isEmpty()) {
                    continue;
                } // Of if

                String[] tempDatum = tempLine.split(paraSplitSign);
                double[] tempData = new double[tempDatum.length];
                for (int i = 0; i < tempDatum.length; i++) {
                    tempData[i] = Double.parseDouble(tempDatum[i]);
                } // Of for i
                Instance tempInstance = new Instance(tempData);
                append(tempInstance);
            } // Of while
        } catch (IOException e) {
            e.printStackTrace();
            System.out.println("Unable to load " + paraFilename);
            // A non-zero exit code signals failure.
            System.exit(1);
        }//Of try
    }// Of the second constructor

    /**
     * **********************
     * Append an instance.
     *
     * @param paraInstance The given record.
     * **********************
     */
    public void append(Instance paraInstance) {
        instances.add(paraInstance);
    }// Of append

    /**
     * **********************
     * Append an instance  specified by double values.
     * **********************
     */
    public void append(double[] paraAttributes, Double paraLabel) {
        instances.add(new Instance(paraAttributes, paraLabel));
    }// Of append

    /**
     * **********************
     * Getter.
     * **********************
     */
    public Instance getInstance(int paraIndex) {
        return instances.get(paraIndex);
    }// Of getInstance

    /**
     * **********************
     * Getter.
     * **********************
     */
    public int size() {
        return instances.size();
    }// Of size

    /**
     * **********************
     * Getter.
     * **********************
     */
    public double[] getAttributes(int paraIndex) {
        return instances.get(paraIndex).getAttributes();
    }// Of getAttributes

    /**
     * **********************
     * Getter.
     * **********************
     */
    public Double getLabel(int paraIndex) {
        return instances.get(paraIndex).getLabel();
    }// Of getLabel

    /**
     * **********************
     * Unit test.
     * **********************
     */
    public static void main(String[] args) {
        Dataset tempData = new Dataset("D:/Work/Data/sampledata/train.format", ",", 784);
        Instance tempInstance = tempData.getInstance(0);
        System.out.println("The first instance is: " + tempInstance);
        System.out.println("The first instance label is: " + tempInstance.label);

        tempInstance = tempData.getInstance(1);
        System.out.println("The second instance is: " + tempInstance);
        System.out.println("The second instance label is: " + tempInstance.label);
    }// Of main

    /**
     * **********************
     * An instance.
     * **********************
     */
    public class Instance {
        /**
         * Conditional attributes.
         */
        private double[] attributes;

        /**
         * Label.
         */
        private Double label;

        /**
         * **********************
         * The first constructor.
         * **********************
         */
        private Instance(double[] paraAttrs, Double paraLabel) {
            attributes = paraAttrs;
            label = paraLabel;
        }//Of the first constructor

        /**
         * **********************
         * The second constructor.
         * **********************
         */
        public Instance(double[] paraData) {
            if (labelIndex == -1) {
                // No label
                attributes = paraData;
            } else {
                label = paraData[labelIndex];
                if (label > maxLabel) {
                    // It is a new label
                    maxLabel = label;
                } // Of if

                if (labelIndex == 0) {
                    // The first column is the label
                    attributes = Arrays.copyOfRange(paraData, 1, paraData.length);
                } else {
                    // The last column is the label
                    attributes = Arrays.copyOfRange(paraData, 0, paraData.length - 1);
                } // Of if
            } // Of if
        }// Of the second constructor

        /**
         * **********************
         * Getter.
         * **********************
         */
        public double[] getAttributes() {
            return attributes;
        }// Of getAttributes

        /**
         * **********************
         * Getter.
         * **********************
         */
        public Double getLabel() {
            if (labelIndex == -1)
                return null;
            return label;
        }// Of getLabel

        /**
         * **********************
         * toString.
         * **********************
         */
        public String toString() {
            return Arrays.toString(attributes) + ", " + label;
        }//Of toString
    }// Of class Instance
} //Of class Dataset

3. Run screenshot

[Figure 3: screenshot of the program output]

II. Basic operations on the kernel size

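1. Operations

These operations work on the kernel size, that is, on the width and height of the kernel.

The first operation divides the width and the height by two integers respectively, and throws an error if either one is not evenly divisible. For example:

(4, 12) / (2, 3) -> (2, 4)
(2,  2) / (4, 6) -> Error

The second operation subtracts two integers from the width and the height respectively, and then adds 1 to each. For example:

(4, 6) - (2, 2) + 1 -> (3, 5)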
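In a CNN these two operations typically give the output map size of a convolution layer (subtract the kernel size, add 1) and of a pooling layer (divide by the scale size). A minimal sketch follows. The class SizeSketch, the int[] {width, height} representation, and the 28 × 28 input, 5 × 5 kernel, and 2 × 2 scale are assumptions chosen for illustration.

```java
class SizeSketch {
    /** Element-wise division; fails when either dimension is not evenly divisible. */
    static int[] divide(int[] size, int[] scale) {
        if (size[0] % scale[0] != 0 || size[1] % scale[1] != 0) {
            throw new IllegalArgumentException("Unable to divide ("
                    + size[0] + ", " + size[1] + ") by (" + scale[0] + ", " + scale[1] + ")");
        }
        return new int[] { size[0] / scale[0], size[1] / scale[1] };
    }

    /** Element-wise subtraction, then add the same value to both dimensions. */
    static int[] subtract(int[] size, int[] other, int append) {
        return new int[] { size[0] - other[0] + append, size[1] - other[1] + append };
    }

    public static void main(String[] args) {
        int[] input = { 28, 28 };
        int[] kernel = { 5, 5 }; // assumed kernel size
        int[] scale = { 2, 2 };  // assumed pooling scale

        int[] afterConvolution = subtract(input, kernel, 1);  // 28 - 5 + 1 = 24
        int[] afterPooling = divide(afterConvolution, scale); // 24 / 2 = 12
        System.out.println(afterConvolution[0] + " x " + afterConvolution[1]);
        System.out.println(afterPooling[0] + " x " + afterPooling[1]);
    }
}
```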

2. Code

package cnn;

/**
 * The size of a convolution core.
 *
 * @author Shi-Huai Wen Email: [email protected].
 */
public class Size {
    /**
     * Cannot be changed after initialization.
     */
    public final int width;

    /**
     * Cannot be changed after initialization.
     */
    public final int height;

    /**
     * **********************
     * The first constructor.
     *
     * @param paraWidth  The given width.
     * @param paraHeight The given height.
     * **********************
     */
    public Size(int paraWidth, int paraHeight) {
        width = paraWidth;
        height = paraHeight;
    }// Of the first constructor

    /**
     * **********************
     * Divide a scale with another one. For example (4, 12) / (2, 3) = (2, 4).
     *
     * @param paraScaleSize The given scale size.
     * @return The new size.
     * **********************
     */
    public Size divide(Size paraScaleSize) {
        int resultWidth = width / paraScaleSize.width;
        int resultHeight = height / paraScaleSize.height;
        if (resultWidth * paraScaleSize.width != width || resultHeight * paraScaleSize.height != height) {
            throw new RuntimeException("Unable to divide " + this + " with " + paraScaleSize);
        }
        return new Size(resultWidth, resultHeight);
    }// Of divide

    /**
     * **********************
     * Subtract a scale with another one, and add a value. For example (4, 12) -
     * (2, 3) + 1 = (3, 10).
     *
     * @param paraScaleSize The given scale size.
     * @param paraAppend    The appended size to both dimensions.
     * @return The new size.
     * **********************
     */
    public Size subtract(Size paraScaleSize, int paraAppend) {
        int resultWidth = width - paraScaleSize.width + paraAppend;
        int resultHeight = height - paraScaleSize.height + paraAppend;
        return new Size(resultWidth, resultHeight);
    }// Of subtract


    public String toString() {
        String resultString = "(" + width + ", " + height + ")";
        return resultString;
    }// Of toString

    /**
     * **********************
     * Unit test.
     * **********************
     */
    public static void main(String[] args) {
        Size tempSize1 = new Size(4, 6);
        Size tempSize2 = new Size(2, 2);
        System.out.println("" + tempSize1 + " divide " + tempSize2 + " = " + tempSize1.divide(tempSize2));

        try {
            System.out.println("" + tempSize2 + " divide " + tempSize1 + " = " + tempSize2.divide(tempSize1));
        } catch (Exception ee) {
            System.out.println("Error is :" + ee);
        } // Of try

        System.out.println("" + tempSize1 + " - " + tempSize2 + " + 1 = " + tempSize1.subtract(tempSize2, 1));
    }// Of main
} //Of class Size

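3. Run screenshot

III. Math utility class

1. Utility functions

An operator is defined so that a matrix operation can be applied to every element in turn. Some operators act on a single matrix, for example subtracting each value from 1, or applying the sigmoid function to each value. Others act on two matrices, for example element-wise addition or subtraction of two matrices.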
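Because an operator of this kind has a single method, the same idea can also be written with lambdas. The sketch below is a hypothetical, self-contained variant: the names OperatorSketch, ElementOp, and apply are not from the original code, and it returns a new matrix instead of overwriting its input.

```java
import java.util.Arrays;

class OperatorSketch {
    /** A single-operand element operator (hypothetical name). */
    interface ElementOp {
        double process(double value);
    }

    static final ElementOp ONE_MINUS = value -> 1 - value;
    static final ElementOp SIGMOID = value -> 1 / (1 + Math.exp(-value));

    /** Apply the operator to every element and return a new matrix. */
    static double[][] apply(double[][] matrix, ElementOp op) {
        double[][] result = new double[matrix.length][];
        for (int i = 0; i < matrix.length; i++) {
            result[i] = new double[matrix[i].length];
            for (int j = 0; j < matrix[i].length; j++) {
                result[i][j] = op.process(matrix[i][j]);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] m = { { 0.0, 0.25 }, { 0.5, 1.0 } };
        System.out.println(Arrays.deepToString(apply(m, ONE_MINUS)));
        System.out.println(Arrays.deepToString(apply(m, SIGMOID)));
    }
}
```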

Rotating a matrix by 180 degrees amounts to rotating it by 90 degrees twice. For an n × n matrix, the 90-degree rotation is
$$matrix[row][col] \overset{rotate}{=} matrix_{new}[col][n - row - 1]$$
The rot180 method below reaches the same result by reversing each row and then reversing the order of the rows.
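As a quick check of the rotation rule, the sketch below rotates a square matrix by 90 degrees twice and compares the result with reversing both the row order and the column order. The class and method names are illustrative assumptions.

```java
import java.util.Arrays;

class RotateSketch {
    /** Rotate an n x n matrix by 90 degrees: result[col][n - row - 1] = matrix[row][col]. */
    static double[][] rot90(double[][] matrix) {
        int n = matrix.length;
        double[][] result = new double[n][n];
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                result[col][n - row - 1] = matrix[row][col];
            }
        }
        return result;
    }

    /** Rotate by 180 degrees directly: reverse the row order and the column order. */
    static double[][] flipBoth(double[][] matrix) {
        int m = matrix.length;
        int n = matrix[0].length;
        double[][] result = new double[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                result[m - 1 - i][n - 1 - j] = matrix[i][j];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] a = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
        double[][] twice = rot90(rot90(a));
        System.out.println(Arrays.deepToString(twice));
        System.out.println(Arrays.deepEquals(twice, flipBoth(a)));
    }
}
```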

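convnValid is the convolution operation in "valid" mode, with no padding, and is used in the forward pass. convnFull is its counterpart for back-propagation: it zero-pads the matrix on every side and then calls convnValid, so the output is larger than the input.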
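A minimal sketch of the two modes, assuming a plain sliding-window sum with no kernel flipping; the class name ConvSketch and the sample values are illustrative. With an m × n input and a km × kn kernel, the valid output is (m - km + 1) × (n - kn + 1) and the full output is (m + km - 1) × (n + kn - 1).

```java
class ConvSketch {
    /** "Valid" mode: slide the kernel only over positions where it fits entirely. */
    static double[][] valid(double[][] matrix, double[][] kernel) {
        int km = kernel.length;
        int kn = kernel[0].length;
        int rows = matrix.length - km + 1;
        int cols = matrix[0].length - kn + 1;
        double[][] out = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double sum = 0.0;
                for (int ki = 0; ki < km; ki++) {
                    for (int kj = 0; kj < kn; kj++) {
                        sum += matrix[i + ki][j + kj] * kernel[ki][kj];
                    }
                }
                out[i][j] = sum;
            }
        }
        return out;
    }

    /** "Full" mode: zero-pad by (kernel size - 1) on every side, then run valid mode. */
    static double[][] full(double[][] matrix, double[][] kernel) {
        int m = matrix.length;
        int n = matrix[0].length;
        int km = kernel.length;
        int kn = kernel[0].length;
        double[][] padded = new double[m + 2 * (km - 1)][n + 2 * (kn - 1)];
        for (int i = 0; i < m; i++) {
            System.arraycopy(matrix[i], 0, padded[i + km - 1], kn - 1, n);
        }
        return valid(padded, kernel);
    }

    public static void main(String[] args) {
        double[][] image = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
        double[][] kernel = { { 1, 0 }, { 0, 1 } };
        double[][] v = valid(image, kernel); // 3 - 2 + 1 = 2 rows and columns
        double[][] f = full(image, kernel);  // 3 + 2 - 1 = 4 rows and columns
        System.out.println(v.length + " x " + v[0].length);
        System.out.println(f.length + " x " + f[0].length);
    }
}
```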

scaleMatrix is mean pooling. kronecker is the reverse operation of pooling: it expands each element into a block of the scale size.
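The relationship between the two can be sketched as follows, assuming a single square scale factor for brevity (the original methods take a Size with separate width and height). The class and method names are illustrative.

```java
import java.util.Arrays;

class PoolSketch {
    /** Mean pooling: average each scale x scale block into one value. */
    static double[][] meanPool(double[][] matrix, int scale) {
        int m = matrix.length;
        int n = matrix[0].length;
        if (m % scale != 0 || n % scale != 0) {
            throw new IllegalArgumentException("Matrix size is not a multiple of the scale");
        }
        double[][] out = new double[m / scale][n / scale];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                out[i / scale][j / scale] += matrix[i][j] / (scale * scale);
            }
        }
        return out;
    }

    /** Upsampling: copy each value into a scale x scale block. */
    static double[][] upsample(double[][] matrix, int scale) {
        int m = matrix.length;
        int n = matrix[0].length;
        double[][] out = new double[m * scale][n * scale];
        for (int i = 0; i < m * scale; i++) {
            for (int j = 0; j < n * scale; j++) {
                out[i][j] = matrix[i / scale][j / scale];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] a = { { 1, 3, 2, 2 }, { 5, 7, 2, 2 }, { 0, 0, 4, 4 }, { 0, 0, 4, 4 } };
        double[][] pooled = meanPool(a, 2);
        System.out.println(Arrays.deepToString(pooled));
        System.out.println(Arrays.deepToString(upsample(pooled, 2)));
    }
}
```

Upsampling restores the shape of the matrix, not its original values: every element of a block receives the same pooled value.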

2. Code

package cnn;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

/**
 * Math operations. Adopted from cnn-master.
 *
 * @author Shi-Huai Wen Email: [email protected].
 */
public class MathUtils {
    /**
     * An interface for different on-demand operators.
     */
    public interface Operator extends Serializable {
        double process(double value);
    }// Of interface Operator

    /**
     * The one-minus-the-value operator.
     */
    public static final Operator one_value = new Operator() {
        private static final long serialVersionUID = 3752139491940330714L;

        @Override
        public double process(double value) {
            return 1 - value;
        }// Of process
    };

    /**
     * The sigmoid operator.
     */
    public static final Operator sigmoid = new Operator() {
        private static final long serialVersionUID = -1952718905019847589L;

        @Override
        public double process(double value) {
            return 1 / (1 + Math.pow(Math.E, -value));
        }// Of process
    };

    /**
     * An interface for operations with two operands.
     */
    interface OperatorOnTwo extends Serializable {
        double process(double a, double b);
    }// Of interface OperatorOnTwo

    /**
     * Plus.
     */
    public static final OperatorOnTwo plus = new OperatorOnTwo() {
        private static final long serialVersionUID = -6298144029766839945L;

        @Override
        public double process(double a, double b) {
            return a + b;
        }// Of process
    };

    /**
     * Multiply.
     */
    public static OperatorOnTwo multiply = new OperatorOnTwo() {

        private static final long serialVersionUID = -7053767821858820698L;

        @Override
        public double process(double a, double b) {
            return a * b;
        }// Of process
    };

    /**
     * Minus.
     */
    public static OperatorOnTwo minus = new OperatorOnTwo() {

        private static final long serialVersionUID = 7346065545555093912L;

        @Override
        public double process(double a, double b) {
            return a - b;
        }// Of process
    };

    /**
     * **********************
     * Print a matrix
     * **********************
     */
    public static void printMatrix(double[][] matrix) {
        for (double[] array : matrix) {
            String line = Arrays.toString(array);
            line = line.replaceAll(", ", "\t");
            System.out.println(line);
        } // Of for i
        System.out.println();
    }// Of printMatrix

    /**
     * **********************
     * Clone a matrix. Do not use its reference directly.
     * **********************
     */
    public static double[][] cloneMatrix(final double[][] matrix) {
        final int m = matrix.length;
        int n = matrix[0].length;
        final double[][] outMatrix = new double[m][n];

        for (int i = 0; i < m; i++) {
            System.arraycopy(matrix[i], 0, outMatrix[i], 0, n);
        } // Of for i
        return outMatrix;
    }// Of cloneMatrix

    /**
     * **********************
     * Rotate the matrix 180 degrees.
     * **********************
     */
    public static double[][] rot180(double[][] matrix) {
        matrix = cloneMatrix(matrix);
        int m = matrix.length;
        int n = matrix[0].length;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n / 2; j++) {
                double tmp = matrix[i][j];
                matrix[i][j] = matrix[i][n - 1 - j];
                matrix[i][n - 1 - j] = tmp;
            }
        }
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < m / 2; i++) {
                double tmp = matrix[i][j];
                matrix[i][j] = matrix[m - 1 - i][j];
                matrix[m - 1 - i][j] = tmp;
            }
        }
        return matrix;
    }// Of rot180

    private static final Random myRandom = new Random(2);

    /**
     * **********************
     * Generate a random matrix with the given size. Each value lies in
     * [-0.005, 0.095). The parameter b is not used.
     * **********************
     */
    public static double[][] randomMatrix(int x, int y, boolean b) {
        double[][] matrix = new double[x][y];
        // int tag = 1;
        for (int i = 0; i < x; i++) {
            for (int j = 0; j < y; j++) {
                matrix[i][j] = (myRandom.nextDouble() - 0.05) / 10;
            } // Of for j
        } // Of for i
        return matrix;
    }// Of randomMatrix

    /**
     * **********************
     * Generate an array with the given length. Despite the name, every value
     * is currently set to 0; the random initialization is commented out.
     * **********************
     */
    public static double[] randomArray(int len) {
        double[] data = new double[len];
        for (int i = 0; i < len; i++) {
            //data[i] = myRandom.nextDouble() / 10 - 0.05;
            data[i] = 0;
        } // Of for i
        return data;
    }// Of randomArray

    /**
     * **********************
     * Pick batchSize distinct random indices from [0, size). batchSize must
     * not exceed size; otherwise the loop never terminates.
     * **********************
     */
    public static int[] randomPerm(int size, int batchSize) {
        Set<Integer> set = new HashSet<>();
        while (set.size() < batchSize) {
            set.add(myRandom.nextInt(size));
        }
        int[] randPerm = new int[batchSize];
        int i = 0;
        for (Integer value : set) {
            randPerm[i++] = value;
        }
        return randPerm;
    }// Of randomPerm

    /**
     * **********************
     * Matrix operation with the given operator on a single operand. The
     * matrix ma is modified in place and returned.
     * **********************
     */
    public static double[][] matrixOp(final double[][] ma, Operator operator) {
        final int m = ma.length;
        int n = ma[0].length;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                ma[i][j] = operator.process(ma[i][j]);
            } // Of for j
        } // Of for i
        return ma;
    }// Of matrixOp

    /**
     * **********************
     * Matrix operation with the given operator on two operands. The result is
     * written into mb, which is returned; ma is left unchanged.
     * **********************
     */
    public static double[][] matrixOp(final double[][] ma, final double[][] mb,
                                      final Operator operatorA, final Operator operatorB, OperatorOnTwo operator) {
        final int m = ma.length;
        int n = ma[0].length;
        if (m != mb.length || n != mb[0].length)
            throw new RuntimeException("ma.length:" + ma.length + "  mb.length:" + mb.length);

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                double a = ma[i][j];
                if (operatorA != null) {
                    a = operatorA.process(a);
                }

                double b = mb[i][j];
                if (operatorB != null) {
                    b = operatorB.process(b);
                }

                mb[i][j] = operator.process(a, b);
            } // Of for j
        } // Of for i
        return mb;
    }// Of matrixOp

    /**
     * **********************
     * Extend the matrix to a bigger one (a number of times).
     * **********************
     */
    public static double[][] kronecker(final double[][] matrix, final Size scale) {
        final int m = matrix.length;
        int n = matrix[0].length;
        final double[][] outMatrix = new double[m * scale.width][n * scale.height];

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                for (int ki = i * scale.width; ki < (i + 1) * scale.width; ki++) {
                    for (int kj = j * scale.height; kj < (j + 1) * scale.height; kj++) {
                        outMatrix[ki][kj] = matrix[i][j];
                    }
                }
            }
        }
        return outMatrix;
    }// Of kronecker

    /**
     * **********************
     * Scale the matrix.
     * **********************
     */
    public static double[][] scaleMatrix(final double[][] matrix, final Size scale) {
        int m = matrix.length;
        int n = matrix[0].length;
        final int sm = m / scale.width;
        final int sn = n / scale.height;
        final double[][] outMatrix = new double[sm][sn];
        if (sm * scale.width != m || sn * scale.height != n)
            throw new RuntimeException("scale matrix");
        final int size = scale.width * scale.height;
        for (int i = 0; i < sm; i++) {
            for (int j = 0; j < sn; j++) {
                double sum = 0.0;
                for (int si = i * scale.width; si < (i + 1) * scale.width; si++) {
                    for (int sj = j * scale.height; sj < (j + 1) * scale.height; sj++) {
                        sum += matrix[si][sj];
                    } // Of for sj
                } // Of for si
                outMatrix[i][j] = sum / size;
            } // Of for j
        } // Of for i
        return outMatrix;
    }// Of scaleMatrix

    /**
     * **********************
     * Convolution full to obtain a bigger size. It is used in back-propagation.
     * **********************
     */
    public static double[][] convnFull(double[][] matrix, final double[][] kernel) {
        int m = matrix.length;
        int n = matrix[0].length;
        final int km = kernel.length;
        final int kn = kernel[0].length;
        final double[][] extendMatrix = new double[m + 2 * (km - 1)][n + 2 * (kn - 1)];
        for (int i = 0; i < m; i++) {
            System.arraycopy(matrix[i], 0, extendMatrix[i + km - 1], kn - 1, n);
        } // Of for i
        return convnValid(extendMatrix, kernel);
    }// Of convnFull

    /**
     * **********************
     * Convolution operation, from a given matrix and a kernel, sliding and sum
     * to obtain the result matrix. It is used in forward.
     * **********************
     */
    public static double[][] convnValid(final double[][] matrix, double[][] kernel) {
        int m = matrix.length;
        int n = matrix[0].length;
        final int km = kernel.length;
        final int kn = kernel[0].length;
        int kns = n - kn + 1;
        final int kms = m - km + 1;
        final double[][] outMatrix = new double[kms][kns];

        for (int i = 0; i < kms; i++) {
            for (int j = 0; j < kns; j++) {
                double sum = 0.0;
                for (int ki = 0; ki < km; ki++) {
                    for (int kj = 0; kj < kn; kj++)
                        sum += matrix[i + ki][j + kj] * kernel[ki][kj];
                }
                outMatrix[i][j] = sum;

            }
        }
        return outMatrix;
    }// Of convnValid

    /**
     * **********************
     * Convolution on a tensor.
     * **********************
     */
    public static double[][] convnValid(final double[][][][] matrix, int mapNoX, double[][][][] kernel, int mapNoY) {
        int m = matrix.length;
        int n = matrix[0][mapNoX].length;
        int h = matrix[0][mapNoX][0].length;
        int km = kernel.length;
        int kn = kernel[0][mapNoY].length;
        int kh = kernel[0][mapNoY][0].length;
        int kms = m - km + 1;
        int kns = n - kn + 1;
        int khs = h - kh + 1;
        if (matrix.length != kernel.length)
            throw new RuntimeException("length");
        final double[][][] outMatrix = new double[kms][kns][khs];
        for (int i = 0; i < kms; i++) {
            for (int j = 0; j < kns; j++)
                for (int k = 0; k < khs; k++) {
                    double sum = 0.0;
                    for (int ki = 0; ki < km; ki++) {
                        for (int kj = 0; kj < kn; kj++)
                            for (int kk = 0; kk < kh; kk++) {
                                sum += matrix[i + ki][mapNoX][j + kj][k + kk] * kernel[ki][mapNoY][kj][kk];
                            }
                    }
                    outMatrix[i][j][k] = sum;
                }
        }
        return outMatrix[0];
    }// Of convnValid

    /**
     * **********************
     * The sigmoid operation.
     * **********************
     */
    public static double sigmoid(double x) {
        return 1 / (1 + Math.pow(Math.E, -x));
    }// Of sigmoid

    /**
     * **********************
     * Sum all values of a matrix.
     * **********************
     */
    public static double sum(double[][] error) {
        int n = error[0].length;
        double sum = 0.0;
        for (double[] array : error) {
            for (int i = 0; i < n; i++) {
                sum += array[i];
            }
        }
        return sum;
    }// Of sum

    /**
     * **********************
     * Ad hoc sum.
     * **********************
     */
    public static double[][] sum(double[][][][] errors, int j) {
        int m = errors[0][j].length;
        int n = errors[0][j][0].length;
        double[][] result = new double[m][n];
        for (int mi = 0; mi < m; mi++) {
            for (int nj = 0; nj < n; nj++) {
                double sum = 0;
                for (double[][][] error : errors) {
                    sum += error[j][mi][nj];
                }
                result[mi][nj] = sum;
            }
        }
        return result;
    }// Of sum

    /**
     * **********************
     * Get the index of the maximal value for the final classification.
     * **********************
     */
    public static int getMaxIndex(double[] out) {
        double max = out[0];
        int index = 0;
        for (int i = 1; i < out.length; i++)
            if (out[i] > max) {
                max = out[i];
                index = i;
            }
        return index;
    }// Of getMaxIndex
} //Of class MathUtils

An enum is defined here to identify the type of each layer, such as the input layer or the convolution layer.

package cnn;

/**
 * Enumerate all layer types.
 *
 * @author Shi-Huai Wen Email: [email protected].
 */
public enum LayerTypeEnum {
    INPUT, CONVOLUTION, SAMPLING, OUTPUT;
} //Of enum LayerTypeEnum

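IV. Network structure and parameters

This class provides utility methods for a single layer, and uses the enum LayerTypeEnum above to tell apart the different layers of the network, such as the input layer, convolution layers, and pooling layers.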

package cnn;

/**
 * One layer, support all four layer types. The code mainly initializes, gets,
 * and sets variables. Essentially no algorithm is implemented.
 *
 * @author Shi-Huai Wen Email: [email protected].
 */
public class CnnLayer {
    /**
     * The type of the layer.
     */
    LayerTypeEnum type;

    /**
     * The number of out map.
     */
    int outMapNum;

    /**
     * The map size.
     */
    Size mapSize;

    /**
     * The kernel size.
     */
    Size kernelSize;

    /**
     * The scale size.
     */
    Size scaleSize;

    /**
     * The number of classes. Only set for the OUTPUT layer.
     */
    int classNum = -1;

    /**
     * Kernel. Dimensions: [front map][out map][width][height].
     */
    private double[][][][] kernel;

    /**
     * Bias. The length is outMapNum.
     */
    private double[] bias;

    /**
     * Out maps. Dimensions:
     * [batchSize][outMapNum][mapSize.width][mapSize.height].
     */
    private double[][][][] outMaps;

    /**
     * Errors.
     */
    private double[][][][] errors;

    /**
     * For batch processing.
     */
    private static int recordInBatch = 0;

    /**
     * **********************
     * The first constructor.
     *
     * @param paraType Describe which Layer
     * @param paraNum  When the type is CONVOLUTION, it is the out map number. when
     *                 the type is OUTPUT, it is the class number.
     * @param paraSize When the type is INPUT, it is the map size; when the type is
     *                 CONVOLUTION, it is the kernel size; when the type is SAMPLING,
     *                 it is the scale size.
     * **********************
     */
    public CnnLayer(LayerTypeEnum paraType, int paraNum, Size paraSize) {
        type = paraType;
        switch (type) {
            case INPUT:
                outMapNum = 1;
                mapSize = paraSize; // No deep copy.
                break;
            case CONVOLUTION:
                outMapNum = paraNum;
                kernelSize = paraSize;
                break;
            case SAMPLING:
                scaleSize = paraSize;
                break;
            case OUTPUT:
                classNum = paraNum;
                mapSize = new Size(1, 1);
                outMapNum = classNum;
                break;
            default:
                System.out.println("Internal error occurred in CnnLayer.java constructor.");
        }// Of switch
    }// Of the first constructor

    /**
     * **********************
     * Initialize the kernel.
     *
     * @param paraFrontMapNum The number of out maps of the previous (front) layer.
     * **********************
     */
    public void initKernel(int paraFrontMapNum) {
        kernel = new double[paraFrontMapNum][outMapNum][][];
        for (int i = 0; i < paraFrontMapNum; i++) {
            for (int j = 0; j < outMapNum; j++) {
                kernel[i][j] = MathUtils.randomMatrix(kernelSize.width, kernelSize.height, true);
            } // Of for j
        } // Of for i
    }// Of initKernel

    /**
     * **********************
     * Initialize the output kernel. The code is revised to invoke initKernel(int).
     * **********************
     */
    public void initOutputKernel(int paraFrontMapNum, Size paraSize) {
        kernelSize = paraSize;
        initKernel(paraFrontMapNum);
    }// Of initOutputKernel

    /**
     * **********************
     * Initialize the bias. No parameter. "int frontMapNum" is claimed however not used.
     * **********************
     */
    public void initBias() {
        bias = MathUtils.randomArray(outMapNum);
    }// Of initBias

    /**
     * **********************
     * Initialize the errors.
     *
     * @param paraBatchSize The batch size.
     * **********************
     */
    public void initErrors(int paraBatchSize) {
        errors = new double[paraBatchSize][outMapNum][mapSize.width][mapSize.height];
    }// Of initErrors

    /**
     * **********************
     * Initialize out maps.
     *
     * @param paraBatchSize The batch size.
     * **********************
     */
    public void initOutMaps(int paraBatchSize) {
        outMaps = new double[paraBatchSize][outMapNum][mapSize.width][mapSize.height];
    }// Of initOutMaps

    /**
     * **********************
     * Prepare for a new batch.
     * **********************
     */
    public static void prepareForNewBatch() {
        recordInBatch = 0;
    }// Of prepareForNewBatch

    /**
     * **********************
     * Prepare for a new record.
     * **********************
     */
    public static void prepareForNewRecord() {
        recordInBatch++;
    }// Of prepareForNewRecord

    /**
     * **********************
     * Set one value of outMaps.
     * **********************
     */
    public void setMapValue(int paraMapNo, int paraX, int paraY, double paraValue) {
        outMaps[recordInBatch][paraMapNo][paraX][paraY] = paraValue;
    }// Of setMapValue

    /**
     * **********************
     * Set values of the whole map.
     * **********************
     */
    public void setMapValue(int paraMapNo, double[][] paraOutMatrix) {
        outMaps[recordInBatch][paraMapNo] = paraOutMatrix;
    }// Of setMapValue

    /**
     * **********************
     * Getter.
     * **********************
     */
    public Size getMapSize() {
        return mapSize;
    }// Of getMapSize

    /**
     * **********************
     * Setter.
     * **********************
     */
    public void setMapSize(Size paraMapSize) {
        mapSize = paraMapSize;
    }// Of setMapSize

    /**
     * **********************
     * Getter.
     * **********************
     */
    public LayerTypeEnum getType() {
        return type;
    }// Of getType

    /**
     * **********************
     * Getter.
     * **********************
     */
    public int getOutMapNum() {
        return outMapNum;
    }// Of getOutMapNum

    /**
     * **********************
     * Setter.
     * **********************
     */
    public void setOutMapNum(int paraOutMapNum) {
        outMapNum = paraOutMapNum;
    }// Of setOutMapNum

    /**
     * **********************
     * Getter.
     * **********************
     */
    public Size getKernelSize() {
        return kernelSize;
    }// Of getKernelSize

    /**
     * **********************
     * Getter.
     * **********************
     */
    public Size getScaleSize() {
        return scaleSize;
    }// Of getScaleSize

    /**
     * **********************
     * Getter.
     * **********************
     */
    public double[][] getMap(int paraIndex) {
        return outMaps[recordInBatch][paraIndex];
    }// Of getMap

    /**
     * **********************
     * Getter.
     * **********************
     */
    public double[][] getKernel(int paraFrontMap, int paraOutMap) {
        return kernel[paraFrontMap][paraOutMap];
    }// Of getKernel

    /**
     * **********************
     * Setter. Set one error.
     * **********************
     */
    public void setError(int paraMapNo, int paraMapX, int paraMapY, double paraValue) {
        errors[recordInBatch][paraMapNo][paraMapX][paraMapY] = paraValue;
    }// Of setError

    /**
     * **********************
     * Setter. Set one error matrix.
     * **********************
     */
    public void setError(int paraMapNo, double[][] paraMatrix) {
        errors[recordInBatch][paraMapNo] = paraMatrix;
    }// Of setError

    /**
     * **********************
     * Getter. Get one error matrix.
     * **********************
     */
    public double[][] getError(int paraMapNo) {
        return errors[recordInBatch][paraMapNo];
    }// Of getError

    /**
     * **********************
     * Getter. Get the whole error tensor.
     * **********************
     */
    public double[][][][] getErrors() {
        return errors;
    }// Of getErrors

    /**
     * **********************
     * Setter. Set one kernel.
     * **********************
     */
    public void setKernel(int paraLastMapNo, int paraMapNo, double[][] paraKernel) {
        kernel[paraLastMapNo][paraMapNo] = paraKernel;
    }// Of setKernel

    /**
     * **********************
     * Getter.
     * **********************
     */
    public double getBias(int paraMapNo) {
        return bias[paraMapNo];
    }// Of getBias

    /**
     * **********************
     * Setter.
     * **********************
     */
    public void setBias(int paraMapNo, double paraValue) {
        bias[paraMapNo] = paraValue;
    }// Of setBias

    /**
     * **********************
     * Getter.
     * **********************
     */
    public double[][][][] getMaps() {
        return outMaps;
    }// Of getMaps

    /**
     * **********************
     * Getter.
     * **********************
     */
    public double[][] getError(int paraRecordId, int paraMapNo) {
        return errors[paraRecordId][paraMapNo];
    }// Of getError

    /**
     * **********************
     * Getter.
     * **********************
     */
    public double[][] getMap(int paraRecordId, int paraMapNo) {
        return outMaps[paraRecordId][paraMapNo];
    }// Of getMap

    /**
     * **********************
     * Getter.
     * **********************
     */
    public int getClassNum() {
        return classNum;
    }// Of getClassNum

    /**
     * **********************
     * Getter. Get the whole kernel tensor.
     * **********************
     */
    public double[][][][] getKernel() {
        return kernel;
    } // Of getKernel
}//Of class CnnLayer

One more wrapper is built on top of the CnnLayer class so that the layers of the network can be created more conveniently.

package cnn;

import java.util.ArrayList;
import java.util.List;

/**
 * CnnLayer builder.
 *
 * @author Shi-Huai Wen Email: [email protected].
 */
public class LayerBuilder {
    /**
     * Layers.
     */
    private List<CnnLayer> layers;

    /**
     * **********************
     * The first constructor.
     * **********************
     */
    public LayerBuilder() {
        layers = new ArrayList<>();
    }// Of the first constructor

    /**
     * **********************
     * The second constructor.
     * **********************
     */
    public LayerBuilder(CnnLayer paraLayer) {
        this();
        layers.add(paraLayer);
    }// Of the second constructor

    /**
     * **********************
     * Add a layer.
     *
     * @param paraLayer The new layer.
     * **********************
     */
    public void addLayer(CnnLayer paraLayer) {
        layers.add(paraLayer);
    }// Of addLayer

    /**
     * **********************
     * Get the specified layer.
     *
     * @param paraIndex The index of the layer.
     * **********************
     */
    public CnnLayer getLayer(int paraIndex) throws RuntimeException {
        if (paraIndex >= layers.size()) {
            throw new RuntimeException("CnnLayer " + paraIndex + " is out of range: " + layers.size() + ".");
        }//Of if

        return layers.get(paraIndex);
    }//Of getLayer

    /**
     * **********************
     * Get the output layer.
     * **********************
     */
    public CnnLayer getOutputLayer() {
        return layers.get(layers.size() - 1);
    }//Of getOutputLayer

    /**
     * **********************
     * Get the number of layers.
     * **********************
     */
    public int getNumLayers() {
        return layers.size();
    }//Of getNumLayers
} //Of class LayerBuilder

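V. Building the neural network

1. Forward propagation

The basics of forward propagation were covered earlier, so here is only a brief recap.

An image passes through the convolution kernels to produce feature maps; the feature maps are then pooled by whichever pooling layer is chosen; finally, an activation function is applied, and the activated output serves as the input of the next convolution layer. (In the code below, the sigmoid activation is applied inside the convolution step, before pooling.)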
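The three steps can be sketched in a few lines of Java. This is only an illustrative sketch, not the MathUtils implementation used in this post: the class name ForwardSketch and its helper methods are hypothetical, mean pooling is assumed for the sampling step, and the order follows the FullCnn code (convolution, then sigmoid, then sampling).

```java
import java.util.Arrays;

class ForwardSketch {
    /** Valid sliding-window sum of products (no padding, kernel not flipped). */
    static double[][] convValid(double[][] map, double[][] kernel) {
        int rows = map.length - kernel.length + 1;
        int cols = map[0].length - kernel[0].length + 1;
        double[][] out = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                for (int a = 0; a < kernel.length; a++) {
                    for (int b = 0; b < kernel[0].length; b++) {
                        out[i][j] += map[i + a][j + b] * kernel[a][b];
                    }
                }
            }
        }
        return out;
    }

    /** Adds the bias and applies the sigmoid to every element. */
    static double[][] activate(double[][] map, double bias) {
        double[][] out = new double[map.length][map[0].length];
        for (int i = 0; i < map.length; i++) {
            for (int j = 0; j < map[0].length; j++) {
                out[i][j] = 1.0 / (1.0 + Math.exp(-(map[i][j] + bias)));
            }
        }
        return out;
    }

    /** Mean pooling over non-overlapping scale x scale windows (an assumption). */
    static double[][] meanPool(double[][] map, int scale) {
        int rows = map.length / scale;
        int cols = map[0].length / scale;
        double[][] out = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double sum = 0;
                for (int a = 0; a < scale; a++) {
                    for (int b = 0; b < scale; b++) {
                        sum += map[i * scale + a][j * scale + b];
                    }
                }
                out[i][j] = sum / (scale * scale);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] image = new double[6][6];
        for (double[] row : image) {
            Arrays.fill(row, 1.0);
        }
        double[][] kernel = {{0.5, 0.0, 0.0}, {0.0, 0.5, 0.0}, {0.0, 0.0, 0.5}};
        double[][] feature = activate(convValid(image, kernel), 0.0);
        double[][] pooled = meanPool(feature, 2);
        System.out.println(feature.length + "x" + feature[0].length
                + " -> " + pooled.length + "x" + pooled[0].length);
    }
}
```

A 6 × 6 input with a 3 × 3 kernel gives a 4 × 4 feature map, which a 2 × 2 pooling window reduces to 2 × 2.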

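After repeated convolution, pooling and activation, the data enters the fully connected layer. The fully connected layer also contains a convolution step: it turns the $m \times n$ feature map into a $1 \times n$ vector, which is then passed through the Softmax function and normalized. The index of the largest value in this vector is the index of the most likely class.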
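As a small illustration of this last step, the sketch below normalizes a score vector with Softmax and picks the index of the largest entry. The class SoftmaxSketch and its methods are hypothetical names introduced only for this example; note that the FullCnn code further down uses sigmoid outputs together with MathUtils.getMaxIndex rather than an explicit Softmax.

```java
class SoftmaxSketch {
    /** Numerically stable Softmax: subtract the maximum before exponentiating. */
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] out = new double[scores.length];
        double sum = 0;
        for (int i = 0; i < scores.length; i++) {
            out[i] = Math.exp(scores[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Index of the largest value, i.e. the predicted class. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probabilities = softmax(new double[]{1.0, 2.0, 3.0});
        System.out.println("Predicted class: " + argMax(probabilities));
    }
}
```

Softmax keeps the ordering of the scores, so the predicted class is the same whether the maximum is taken before or after normalization.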

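2. Backpropagation

Backpropagation is a well-worn topic: because the initial kernels are random, a loss function is needed to find the best kernels.

Backpropagation first updates the fully connected layer. This works much as in an ANN: its weights are updated.

Next comes the pooling layer, which is the simplest case because it has no weights of its own and only passes the error back. Taking max pooling as an example: if the pooled value is 6 and the back-propagated error is +1, the +1 is routed to the position that held the maximum value 6, and the other positions in the window receive 0.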
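The way an error travels back through a pooling layer can be sketched as follows. This is a hedged illustration with hypothetical names (PoolBackwardSketch, maxPoolBackward, meanPoolBackward), not code from this post; the FullCnn code below instead expands the error with MathUtils.kronecker, whose exact scaling is not shown here.

```java
class PoolBackwardSketch {
    /** Max pooling: each error goes to the position of its window's maximum. */
    static double[][] maxPoolBackward(double[][] input, double[][] pooledError, int scale) {
        double[][] out = new double[input.length][input[0].length];
        for (int i = 0; i < pooledError.length; i++) {
            for (int j = 0; j < pooledError[0].length; j++) {
                int bestRow = i * scale;
                int bestCol = j * scale;
                for (int a = 0; a < scale; a++) {
                    for (int b = 0; b < scale; b++) {
                        int r = i * scale + a;
                        int c = j * scale + b;
                        if (input[r][c] > input[bestRow][bestCol]) {
                            bestRow = r;
                            bestCol = c;
                        }
                    }
                }
                out[bestRow][bestCol] = pooledError[i][j];
            }
        }
        return out;
    }

    /** Mean pooling: each error is shared equally by the positions of its window. */
    static double[][] meanPoolBackward(double[][] pooledError, int scale) {
        int rows = pooledError.length * scale;
        int cols = pooledError[0].length * scale;
        double[][] out = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                out[r][c] = pooledError[r / scale][c / scale] / (scale * scale);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] input = {{1, 6}, {3, 2}};
        double[][] error = {{1}};
        System.out.println(java.util.Arrays.deepToString(maxPoolBackward(input, error, 2)));
        System.out.println(java.util.Arrays.deepToString(meanPoolBackward(error, 2)));
    }
}
```

With the window {{1, 6}, {3, 2}} and an incoming error of 1, max pooling sends the whole error to the cell that held 6, while mean pooling gives each of the four cells 0.25.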

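The convolution layer is the hardest part, and I still have not fully worked out the derivation. My rough understanding is that one starts from the two-dimensional case to obtain a two-dimensional formula, and then generalizes it to the multi-dimensional case of a neural network.

The Zhihu article 卷积神经网络(CNN)反向传播算法推导 (a derivation of the CNN backpropagation algorithm) gives a detailed derivation and explanation.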
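For orientation, the update rules that the FullCnn code below appears to implement can be written out as follows. This is my reading of the code rather than a derivation, and it rests on assumptions about helpers not shown in this section: that MathUtils.one_value maps $x$ to $1 - x$, that MathUtils.kronecker replicates each error value over an $s \times s$ block, and that MathUtils.matrixOp applies its two operators to its two matrices before combining them. Here $o_m$ is an output, $t_m$ the one-hot target, $a$ an activated out map, $x^{(r)}_i$ the $i$-th out map of the previous layer for record $r$, $B$ the batch size, $\alpha$ is ALPHA and $\lambda$ is LAMBDA.

```latex
\begin{align*}
\delta^{L}_{m} &= o_m\,(1-o_m)\,(t_m-o_m)
  && \text{output layer} \\
\delta^{l}_{i} &= \sum_{j} \operatorname{convFull}\bigl(\delta^{l+1}_{j},\ \operatorname{rot180}(K^{l+1}_{ij})\bigr)
  && \text{sampling layer} \\
\delta^{l}_{m} &= a^{l}_{m}\circ(1-a^{l}_{m})\circ\bigl(\delta^{l+1}_{m}\otimes \mathbf{1}_{s\times s}\bigr)
  && \text{convolution layer} \\
K_{ij} &\leftarrow (1-\lambda\alpha)\,K_{ij}
  + \frac{\alpha}{B}\sum_{r=1}^{B}\operatorname{convValid}\bigl(x^{(r)}_{i},\ \delta^{(r)}_{j}\bigr)
  && \text{kernel update} \\
b_{j} &\leftarrow b_{j} + \frac{\alpha}{B}\sum_{r=1}^{B}\sum_{x,y}\delta^{(r)}_{j}[x,y]
  && \text{bias update}
\end{align*}
```

Because the error is defined with $(t_m - o_m)$, adding the scaled terms moves the parameters in the direction that reduces the squared error.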

4. Code

package cnn;

import java.util.Arrays;

import cnn.Dataset.Instance;
import cnn.MathUtils.Operator;

/**
 * CNN.
 *
 * @author Shi-Huai Wen Email: [email protected].
 */
public class FullCnn {
    /**
     * The learning rate. Its value is decayed during training.
     */
    private static double ALPHA = 0.85;

    /**
     * The weight decay coefficient used in the kernel update. A constant.
     */
    public static double LAMBDA = 0;

    /**
     * Manage layers.
     */
    private static LayerBuilder layerBuilder;

    /**
     * Train using a number of instances simultaneously.
     */
    private int batchSize;

    /**
     * Divide the batch size with the given value.
     */
    private Operator divideBatchSize;

    /**
     * Multiply alpha with the given value.
     */
    private Operator multiplyAlpha;

    /**
     * Multiply lambda and alpha with the given value.
     */
    private Operator multiplyLambda;

    /**
     * **********************
     * The first constructor.
     * **********************
     */
    public FullCnn(LayerBuilder paraLayerBuilder, int paraBatchSize) {
        layerBuilder = paraLayerBuilder;
        batchSize = paraBatchSize;
        setup();
        initOperators();
    }// Of the first constructor

    /**
     * **********************
     * Initialize operators using temporary classes.
     * **********************
     */
    private void initOperators() {
        divideBatchSize = new Operator() {
            private static final long serialVersionUID = 7424011281732651055L;

            @Override
            public double process(double value) {
                return value / batchSize;
            }// Of process
        };

        multiplyAlpha = new Operator() {
            private static final long serialVersionUID = 5761368499808006552L;

            @Override
            public double process(double value) {
                return value * ALPHA;
            }// Of process
        };

        multiplyLambda = new Operator() {
            private static final long serialVersionUID = 4499087728362870577L;

            @Override
            public double process(double value) {
                return value * (1 - LAMBDA * ALPHA);
            }// Of process
        };
    }// Of initOperators

    /**
     * **********************
     * Setup according to the layer builder.
     * **********************
     */
    public void setup() {
        CnnLayer tempInputLayer = layerBuilder.getLayer(0);
        tempInputLayer.initOutMaps(batchSize);

        for (int i = 1; i < layerBuilder.getNumLayers(); i++) {
            CnnLayer tempLayer = layerBuilder.getLayer(i);
            CnnLayer tempFrontLayer = layerBuilder.getLayer(i - 1);
            int tempFrontMapNum = tempFrontLayer.getOutMapNum();
            switch (tempLayer.getType()) {
                case INPUT:
                    // Should not be input. Maybe an error should be thrown out.
                    break;
                case CONVOLUTION:
                    tempLayer.setMapSize(
                            tempFrontLayer.getMapSize().subtract(tempLayer.getKernelSize(), 1));
                    tempLayer.initKernel(tempFrontMapNum);
                    tempLayer.initBias();
                    tempLayer.initErrors(batchSize);
                    tempLayer.initOutMaps(batchSize);
                    break;
                case SAMPLING:
                    tempLayer.setOutMapNum(tempFrontMapNum);
                    tempLayer.setMapSize(tempFrontLayer.getMapSize().divide(tempLayer.getScaleSize()));
                    tempLayer.initErrors(batchSize);
                    tempLayer.initOutMaps(batchSize);
                    break;
                case OUTPUT:
                    tempLayer.initOutputKernel(tempFrontMapNum, tempFrontLayer.getMapSize());
                    tempLayer.initBias();
                    tempLayer.initErrors(batchSize);
                    tempLayer.initOutMaps(batchSize);
                    break;
            }// Of switch
        } // Of for i
    }// Of setup

    /**
     * **********************
     * Forward computing.
     * **********************
     */
    private void forward(Instance instance) {
        setInputLayerOutput(instance);
        for (int l = 1; l < layerBuilder.getNumLayers(); l++) {
            CnnLayer tempCurrentLayer = layerBuilder.getLayer(l);
            CnnLayer tempLastLayer = layerBuilder.getLayer(l - 1);
            switch (tempCurrentLayer.getType()) {
                case CONVOLUTION:
                case OUTPUT:
                    setConvolutionOutput(tempCurrentLayer, tempLastLayer);
                    break;
                case SAMPLING:
                    setSampOutput(tempCurrentLayer, tempLastLayer);
                    break;
                default:
                    break;
            }// Of switch
        } // Of for l
    }// Of forward

    /**
     * **********************
     * Set the in layer output. Given a record, copy its values to the input map.
     * **********************
     */
    private void setInputLayerOutput(Instance paraRecord) {
        CnnLayer tempInputLayer = layerBuilder.getLayer(0);
        Size tempMapSize = tempInputLayer.getMapSize();
        double[] tempAttributes = paraRecord.getAttributes();
        if (tempAttributes.length != tempMapSize.width * tempMapSize.height)
            throw new RuntimeException("input record does not match the map size.");

        for (int i = 0; i < tempMapSize.width; i++) {
            for (int j = 0; j < tempMapSize.height; j++) {
                tempInputLayer.setMapValue(0, i, j, tempAttributes[tempMapSize.height * i + j]);
            } // Of for j
        } // Of for i
    }// Of setInputLayerOutput

    /**
     * **********************
     * Compute the convolution output according to the output of the last layer.
     *
     * @param paraLastLayer the last layer.
     * @param paraLayer     the current layer.
     * **********************
     */
    private void setConvolutionOutput(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {

        final int lastMapNum = paraLastLayer.getOutMapNum();

        // Attention: paraLayer.getOutMapNum() may not be right.
        for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
            double[][] tempSumMatrix = null;
            for (int i = 0; i < lastMapNum; i++) {
                double[][] lastMap = paraLastLayer.getMap(i);
                double[][] kernel = paraLayer.getKernel(i, j);
                if (tempSumMatrix == null) {
                    // On the first map.
                    tempSumMatrix = MathUtils.convnValid(lastMap, kernel);
                } else {
                    // Sum up convolution maps
                    tempSumMatrix = MathUtils.matrixOp(MathUtils.convnValid(lastMap, kernel),
                            tempSumMatrix, null, null, MathUtils.plus);
                } // Of if
            } // Of for i

            // Activation.
            final double bias = paraLayer.getBias(j);
            tempSumMatrix = MathUtils.matrixOp(tempSumMatrix, new Operator() {
                private static final long serialVersionUID = 2469461972825890810L;

                @Override
                public double process(double value) {
                    return MathUtils.sigmoid(value + bias);
                }

            });

            paraLayer.setMapValue(j, tempSumMatrix);
        } // Of for j
    }// Of setConvolutionOutput

    /**
     * **********************
     * Compute the sampling output according to the output of the last layer.
     *
     * @param paraLastLayer the last layer.
     * @param paraLayer     the current layer.
     * **********************
     */
    private void setSampOutput(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
        // int tempLastMapNum = paraLastLayer.getOutMapNum();

        // Attention: paraLayer.getOutMapNum() may not be right.
        for (int i = 0; i < paraLayer.getOutMapNum(); i++) {
            double[][] lastMap = paraLastLayer.getMap(i);
            Size scaleSize = paraLayer.getScaleSize();
            double[][] sampMatrix = MathUtils.scaleMatrix(lastMap, scaleSize);
            paraLayer.setMapValue(i, sampMatrix);
        } // Of for i
    }// Of setSampOutput

    /**
     * **********************
     * Train the cnn.
     * **********************
     */
    public void train(Dataset paraDataset, int paraRounds) {
        for (int t = 0; t < paraRounds; t++) {
            System.out.println("Iteration: " + t);
            int tempNumEpochs = paraDataset.size() / batchSize;
            if (paraDataset.size() % batchSize != 0)
                tempNumEpochs++;

            double tempNumCorrect = 0;
            int tempCount = 0;
            for (int i = 0; i < tempNumEpochs; i++) {
                int[] tempRandomPerm = MathUtils.randomPerm(paraDataset.size(), batchSize);
                CnnLayer.prepareForNewBatch();

                for (int index : tempRandomPerm) {
                    boolean isRight = train(paraDataset.getInstance(index));
                    if (isRight)
                        tempNumCorrect++;
                    tempCount++;
                    CnnLayer.prepareForNewRecord();
                } // Of for index

                updateParameters();
                if (i % 50 == 0) {
                    System.out.print("..");
                    if (i + 50 > tempNumEpochs)
                        System.out.println();
                } // Of if
            } // Of for i
            double p = 1.0 * tempNumCorrect / tempCount;
            if (t % 10 == 1 && p > 0.96) {
                ALPHA = 0.001 + ALPHA * 0.9;
                // logger.info("Set alpha = {}", ALPHA);
            } // Of if
            System.out.println("Training precision: " + p);
            // logger.info("Precision: {}/{}={}.", right, count, p);
        } // Of for t
    }// Of train

    /**
     * **********************
     * Train the cnn with only one record.
     *
     * @param paraRecord The given record.
     * **********************
     */
    private boolean train(Instance paraRecord) {
        forward(paraRecord);
        boolean result = backPropagation(paraRecord);
        return result;
    }// Of train

    /**
     * **********************
     * Back-propagation.
     *
     * @param paraRecord The given record.
     * **********************
     */
    private boolean backPropagation(Instance paraRecord) {
        boolean result = setOutputLayerErrors(paraRecord);
        setHiddenLayerErrors();
        return result;
    }// Of backPropagation

    /**
     * **********************
     * Update parameters.
     * **********************
     */
    private void updateParameters() {
        for (int l = 1; l < layerBuilder.getNumLayers(); l++) {
            CnnLayer layer = layerBuilder.getLayer(l);
            CnnLayer lastLayer = layerBuilder.getLayer(l - 1);
            switch (layer.getType()) {
                case CONVOLUTION:
                case OUTPUT:
                    updateKernels(layer, lastLayer);
                    updateBias(layer, lastLayer);
                    break;
                default:
                    break;
            }// Of switch
        } // Of for l
    }// Of updateParameters

    /**
     * **********************
     * Update bias.
     * **********************
     */
    private void updateBias(final CnnLayer paraLayer, CnnLayer paraLastLayer) {
        final double[][][][] errors = paraLayer.getErrors();
        // int mapNum = paraLayer.getOutMapNum();

        // Attention: getOutMapNum() may not be correct.
        for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
            double[][] error = MathUtils.sum(errors, j);
            double deltaBias = MathUtils.sum(error) / batchSize;
            double bias = paraLayer.getBias(j) + ALPHA * deltaBias;
            paraLayer.setBias(j, bias);
        } // Of for j
    }// Of updateBias

    /**
     * **********************
     * Update kernels.
     * **********************
     */
    private void updateKernels(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
        // int mapNum = paraLayer.getOutMapNum();
        int tempLastMapNum = paraLastLayer.getOutMapNum();

        // Attention: getOutMapNum() may not be right
        for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
            for (int i = 0; i < tempLastMapNum; i++) {
                double[][] tempDeltaKernel = null;
                for (int r = 0; r < batchSize; r++) {
                    double[][] error = paraLayer.getError(r, j);
                    if (tempDeltaKernel == null)
                        tempDeltaKernel = MathUtils.convnValid(paraLastLayer.getMap(r, i), error);
                    else {
                        tempDeltaKernel = MathUtils.matrixOp(
                                          MathUtils.convnValid(paraLastLayer.getMap(r, i), error),
                                          tempDeltaKernel, null, null, MathUtils.plus);
                    } // Of if
                } // Of for r

                tempDeltaKernel = MathUtils.matrixOp(tempDeltaKernel, divideBatchSize);

                double[][] kernel = paraLayer.getKernel(i, j);
                tempDeltaKernel = MathUtils.matrixOp(kernel, tempDeltaKernel, multiplyLambda, multiplyAlpha, MathUtils.plus);
                paraLayer.setKernel(i, j, tempDeltaKernel);
            } // Of for i
        } // Of for j
    }// Of updateKernels

    /**
     * **********************
     * Set errors of all hidden layers.
     * **********************
     */
    private void setHiddenLayerErrors() {
        // System.out.println("setHiddenLayerErrors");
        for (int l = layerBuilder.getNumLayers() - 2; l > 0; l--) {
            CnnLayer layer = layerBuilder.getLayer(l);
            CnnLayer nextLayer = layerBuilder.getLayer(l + 1);

            switch (layer.getType()) {
                case SAMPLING:
                    setSamplingErrors(layer, nextLayer);
                    break;
                case CONVOLUTION:
                    setConvolutionErrors(layer, nextLayer);
                    break;
                default:
                    break;
            }// Of switch
        } // Of for l
    }// Of setHiddenLayerErrors

    /**
     * **********************
     * Set errors of a sampling layer.
     * **********************
     */
    private void setSamplingErrors(final CnnLayer paraLayer, final CnnLayer paraNextLayer) {
        // int mapNum = layer.getOutMapNum();
        int tempNextMapNum = paraNextLayer.getOutMapNum();
        // Attention: getOutMapNum() may not be correct
        for (int i = 0; i < paraLayer.getOutMapNum(); i++) {
            double[][] sum = null;
            for (int j = 0; j < tempNextMapNum; j++) {
                double[][] nextError = paraNextLayer.getError(j);
                double[][] kernel = paraNextLayer.getKernel(i, j);
                if (sum == null) {
                    sum = MathUtils.convnFull(nextError, MathUtils.rot180(kernel));
                } else {
                    sum = MathUtils.matrixOp(
                            MathUtils.convnFull(nextError, MathUtils.rot180(kernel)), sum, null,
                            null, MathUtils.plus);
                } // Of if
            } // Of for j
            paraLayer.setError(i, sum);
        } // Of for i
    }// Of setSamplingErrors

    /**
     * **********************
     * Set errors of a convolution layer.
     * **********************
     */
    private void setConvolutionErrors(final CnnLayer paraLayer, final CnnLayer paraNextLayer) {
        // System.out.println("setConvErrors");
        for (int m = 0; m < paraLayer.getOutMapNum(); m++) {
            Size tempScale = paraNextLayer.getScaleSize();
            double[][] tempNextLayerErrors = paraNextLayer.getError(m);
            double[][] tempMap = paraLayer.getMap(m);
            double[][] tempOutMatrix = MathUtils.matrixOp(tempMap, MathUtils.cloneMatrix(tempMap),
                    null, MathUtils.one_value, MathUtils.multiply);
            tempOutMatrix = MathUtils.matrixOp(tempOutMatrix,
                    MathUtils.kronecker(tempNextLayerErrors, tempScale), null, null,
                    MathUtils.multiply);
            paraLayer.setError(m, tempOutMatrix);
        } // Of for m
    }// Of setConvolutionErrors

    /**
     * **********************
     * Set errors of the output layer.
     * **********************
     */
    private boolean setOutputLayerErrors(Instance paraRecord) {
        CnnLayer tempOutputLayer = layerBuilder.getOutputLayer();
        int tempMapNum = tempOutputLayer.getOutMapNum();

        double[] tempTarget = new double[tempMapNum];
        double[] tempOutMaps = new double[tempMapNum];
        for (int m = 0; m < tempMapNum; m++) {
            double[][] outmap = tempOutputLayer.getMap(m);
            tempOutMaps[m] = outmap[0][0];
        } // Of for m

        int tempLabel = paraRecord.getLabel().intValue();
        tempTarget[tempLabel] = 1;

        for (int m = 0; m < tempMapNum; m++) {
            tempOutputLayer.setError(m, 0, 0,
                    tempOutMaps[m] * (1 - tempOutMaps[m]) * (tempTarget[m] - tempOutMaps[m]));
        } // Of for m

        return tempLabel == MathUtils.getMaxIndex(tempOutMaps);
    }// Of setOutputLayerErrors

    /**
     * **********************
     * Setup the network.
     * **********************
     */
    public void setup(int paraBatchSize) {
        CnnLayer tempInputLayer = layerBuilder.getLayer(0);
        tempInputLayer.initOutMaps(paraBatchSize);

        for (int i = 1; i < layerBuilder.getNumLayers(); i++) {
            CnnLayer tempLayer = layerBuilder.getLayer(i);
            CnnLayer tempLastLayer = layerBuilder.getLayer(i - 1);
            int tempLastMapNum = tempLastLayer.getOutMapNum();
            switch (tempLayer.getType()) {
                case INPUT:
                    break;
                case CONVOLUTION:
                    tempLayer.setMapSize(
                            tempLastLayer.getMapSize().subtract(tempLayer.getKernelSize(), 1));
                    tempLayer.initKernel(tempLastMapNum);
                    tempLayer.initBias();
                    tempLayer.initErrors(paraBatchSize);
                    tempLayer.initOutMaps(paraBatchSize);
                    break;
                case SAMPLING:
                    tempLayer.setOutMapNum(tempLastMapNum);
                    tempLayer.setMapSize(tempLastLayer.getMapSize().divide(tempLayer.getScaleSize()));
                    tempLayer.initErrors(paraBatchSize);
                    tempLayer.initOutMaps(paraBatchSize);
                    break;
                case OUTPUT:
                    tempLayer.initOutputKernel(tempLastMapNum, tempLastLayer.getMapSize());
                    tempLayer.initBias();
                    tempLayer.initErrors(paraBatchSize);
                    tempLayer.initOutMaps(paraBatchSize);
                    break;
            }// Of switch
        } // Of for i
    }// Of setup

    /**
     * **********************
     * Predict for the dataset.
     * **********************
     */
    public int[] predict(Dataset paraDataset) {
        System.out.println("Predicting ... ");
        CnnLayer.prepareForNewBatch();

        int[] resultPredictions = new int[paraDataset.size()];
        double tempCorrect = 0.0;

        Instance tempRecord;
        for (int i = 0; i < paraDataset.size(); i++) {
            tempRecord = paraDataset.getInstance(i);
            forward(tempRecord);
            CnnLayer outputLayer = layerBuilder.getOutputLayer();

            int tempMapNum = outputLayer.getOutMapNum();
            double[] tempOut = new double[tempMapNum];
            for (int m = 0; m < tempMapNum; m++) {
                double[][] outmap = outputLayer.getMap(m);
                tempOut[m] = outmap[0][0];
            } // Of for m

            resultPredictions[i] = MathUtils.getMaxIndex(tempOut);
            if (resultPredictions[i] == tempRecord.getLabel().intValue()) {
                tempCorrect++;
            } // Of if
        } // Of for

        System.out.println("Accuracy: " + tempCorrect / paraDataset.size());
        return resultPredictions;
    }// Of predict

    /**
     * **********************
     * The main entrance.
     * **********************
     */
    public static void main(String[] args) {
        LayerBuilder builder = new LayerBuilder();
        // Input layer, the maps are 28*28
        builder.addLayer(new CnnLayer(LayerTypeEnum.INPUT, -1, new Size(28, 28)));
        // Convolution output has size 24*24, 24=28+1-5
        builder.addLayer(new CnnLayer(LayerTypeEnum.CONVOLUTION, 6, new Size(5, 5)));
        // Sampling output has size 12*12, 12=24/2
        builder.addLayer(new CnnLayer(LayerTypeEnum.SAMPLING, -1, new Size(2, 2)));
        // Convolution output has size 8*8, 8=12+1-5
        builder.addLayer(new CnnLayer(LayerTypeEnum.CONVOLUTION, 12, new Size(5, 5)));
        // Sampling output has size 4*4, 4=8/2
        builder.addLayer(new CnnLayer(LayerTypeEnum.SAMPLING, -1, new Size(2, 2)));
        // output layer, digits 0 - 9.
        builder.addLayer(new CnnLayer(LayerTypeEnum.OUTPUT, 10, null));
        // Construct the full CNN.
        FullCnn tempCnn = new FullCnn(builder, 10);

        Dataset tempTrainingSet = new Dataset("D:/Work/Data/sampledata/train.format", ",", 784);

        // Train the model.
        tempCnn.train(tempTrainingSet, 10);
        // tempCnn.predict(tempTrainingSet);
    }// Of main
}//Of class FullCnn
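
The map sizes quoted in the comments of main can be checked with a few lines of arithmetic. The sketch below is illustrative only; ShapeSketch and its methods are hypothetical names, and it assumes square maps, a convolution output of size map - kernel + 1, and a sampling output of size map / scale, as the comments in main state.

```java
import java.util.Arrays;

class ShapeSketch {
    /** Side length after a valid convolution. */
    static int afterConvolution(int mapSize, int kernelSize) {
        return mapSize - kernelSize + 1;
    }

    /** Side length after sampling with a square window. */
    static int afterSampling(int mapSize, int scaleSize) {
        return mapSize / scaleSize;
    }

    /** Side lengths of the maps of the six layers built in main. */
    static int[] walkThrough() {
        int input = 28;
        int conv1 = afterConvolution(input, 5);
        int samp1 = afterSampling(conv1, 2);
        int conv2 = afterConvolution(samp1, 5);
        int samp2 = afterSampling(conv2, 2);
        // The output layer uses a kernel as large as the previous map.
        int output = afterConvolution(samp2, samp2);
        return new int[]{input, conv1, samp1, conv2, samp2, output};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(walkThrough()));
    }
}
```

The final 1 matches the Size(1, 1) map that the CnnLayer constructor assigns to the OUTPUT layer.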

5. Run screenshot

[Image 4: screenshot of the program run]

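Summary

Convolutional neural networks are easy to understand in principle, but actually writing a framework for one was painful and difficult for me.

First comes the derivation of the formulas for backpropagation. I know gradient descent and matrix differentiation, but only from isolated exercises; when it came to applying them in practice, I could not find a way in.

Then there is the code itself, both the math utilities and the matrix utilities. The matrix rotation part in particular made no sense to me at first.

If all we had to do were turn formulas into code, these problems would presumably solve themselves. But things do not go that way: all of this is something I have to genuinely experience, derive and implement myself.

The road ahead is long and far; I shall search high and low.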
