MNIST Handwritten Digit Recognition — a Java Implementation

Overview:

Goal: implement a neural network from scratch, train it on the MNIST dataset, and end up with a model that, given a picture, can recognize which digit (0-9) it shows.

Tools: Java 1.8+

Library: ND4J

I. Data Analysis

1. About the MNIST dataset

The MNIST dataset is available at http://yann.lecun.com/exdb/mnist/ and consists of four parts:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 47 MB uncompressed, 60,000 samples)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 60 KB uncompressed, 60,000 labels)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB uncompressed, 10,000 samples)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5 KB, 10 KB uncompressed, 10,000 labels)

MNIST comes from the United States' National Institute of Standards and Technology (NIST). The training set consists of digits handwritten by 250 different people, 50% of them high school students and 50% staff of the Census Bureau. The test set contains handwritten digits in the same proportions.
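Each of these files uses the simple IDX binary layout that the reader below parses: a 4-byte magic number (0x00000803 for image files, 0x00000801 for label files), a 4-byte big-endian sample count, then, for image files only, 4-byte row and column counts, followed by the raw data, one unsigned byte per pixel or per label.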

2. Implementing an MNIST utility class

package com.yan.dl4j.Utils;


import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;

public class MnistReadUtil {
    public static final String TRAIN_IMAGES_FILE = "src\\main\\resources\\data\\train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "src\\main\\resources\\data\\train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "src\\main\\resources\\data\\t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "src\\main\\resources\\data\\t10k-labels.idx1-ubyte";

    /**
     * change bytes into a hex string.
     *
     * @param bytes bytes
     * @return the returned hex string
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < bytes.length; i++) {
            String hex = Integer.toHexString(bytes[i] & 0xFF);
            if (hex.length() < 2) {
                sb.append(0);
            }
            sb.append(hex);
        }
        return sb.toString();
    }

    /**
     * get images of 'train' or 'test'
     *
     * @param fileName the file of 'train' or 'test' about image
     * @return a 2-D array in which each row holds one picture
     */
    public static double[][] getImages(String fileName) {
        try{
            return getImages(new FileInputStream(fileName));
        }catch (FileNotFoundException e){
            e.printStackTrace();
        }
        return null;
    }

    public static double[][] getImages(InputStream inputStream) {
        double[][] x;
        try (BufferedInputStream bin = new BufferedInputStream(inputStream)) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000803".equals(bytesToHex(bytes))) {                        // 读取魔数
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16);           // 读取样本总数
                bin.read(bytes, 0, 4);
                int xPixel = Integer.parseInt(bytesToHex(bytes), 16);           // 读取每行所含像素点数
                bin.read(bytes, 0, 4);
                int yPixel = Integer.parseInt(bytesToHex(bytes), 16);           // 读取每列所含像素点数
                x = new double[number][xPixel * yPixel];
                for (int i = 0; i < number; i++) {
                    double[] element = new double[xPixel * yPixel];
                    for (int j = 0; j < xPixel * yPixel; j++) {
                        element[j] = bin.read();                                // 逐一读取像素值
                        // normalization
//                        element[j] = bin.read() / 255.0;
                    }
                    x[i] = element;
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return x;
    }

    /**
     * get labels of `train` or `test`
     *
     * @param fileName the file of 'train' or 'test' about label
     * @return the labels
     */
    public static double[] getLabels(String fileName) {
        try{
            return  getLabels(new FileInputStream(fileName));
        }catch (FileNotFoundException e){
            e.printStackTrace();
        }
        return null;
    }

    public static double[] getLabels(InputStream inputStream) {
        double[] y;
        try (BufferedInputStream bin = new BufferedInputStream(inputStream)) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000801".equals(bytesToHex(bytes))) {
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16);
                y = new double[number];
                for (int i = 0; i < number; i++) {
                    y[i] = bin.read();
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return y;
    }

public static void drawGrayPicture(double[] pixelValues, String fileName) throws IOException {
    int width = 28, height = 28;
    BufferedImage img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
    for (int i = 0; i < width; i++) {
        for (int j = 0; j < height; j++) {
            int pixel = 255 - (int) pixelValues[i * height + j];    // invert: MNIST stores ink as high values
            img.setRGB(j, i, (pixel << 16) + (pixel << 8) + pixel); // gray: R = G = B
        }
    }
    ImageIO.write(img, "JPEG", new File(fileName));
}

public static double[] getSizeBlackWhiteImg(File file, int width, int height) throws IOException {
    // scale the input image down to width x height
    Image small = ImageIO.read(file).getScaledInstance(width, height, Image.SCALE_SMOOTH);
    BufferedImage grayImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
    grayImage.getGraphics().drawImage(small, 0, 0, null);
    double[] result = new double[width * height];
    int[] rgb = new int[3];
    for (int i = 0; i < width; i++) {
        for (int j = 0; j < height; j++) {
            int r = grayImage.getRGB(j, i);
            rgb[0] = (r & 0xff0000) >> 16;
            rgb[1] = (r & 0xff00) >> 8;
            rgb[2] = (r & 0xff);
            // luminance, inverted so dark strokes become high values like MNIST
            int color = 255 - (int) (rgb[0] * 0.3 + rgb[1] * 0.59 + rgb[2] * 0.11);
            result[i * height + j] = color;
        }
    }
    return result;
}
}

In this utility class, getImages converts an MNIST image file into a double[][]. Taking train-images.idx3-ubyte as an example, the result is a double[60000][784]: each double[] is one image, 60,000 images in total, each with 28*28 = 784 feature values. getLabels converts an MNIST label file into a double[]. drawGrayPicture turns a double[] back into an image a human can look at. getSizeBlackWhiteImg grayscales an arbitrary input image, scales it down to 28*28, and converts it into a double[].
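A quick way to sanity-check the reader (a minimal sketch; it assumes the MNIST files sit at the constant paths defined above):

public static void main(String[] args) throws IOException {
    double[][] images = MnistReadUtil.getImages(MnistReadUtil.TRAIN_IMAGES_FILE);
    double[] labels = MnistReadUtil.getLabels(MnistReadUtil.TRAIN_LABELS_FILE);
    System.out.println(images.length + " x " + images[0].length); // 60000 x 784
    System.out.println("first label: " + labels[0]);              // 5.0
    MnistReadUtil.drawGrayPicture(images[0], "first.jpg");        // write the first digit to disk
}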

II. Building the Neural Network

1. Network fields

private List<Layer> layers = new ArrayList<>();
private LastLayer lastLayer;
private INDArray[] Network_W;
private INDArray[] Network_B;
private int nin;
private int seed=123;

private double learningrate=0.01;

private int iteration = 10;

    nin is the input dimension. Layer is the abstraction of one network layer; LastLayer is the final layer, kept separate because it also carries the loss function. Network_W holds each layer's weight matrix and Network_B each layer's bias vector; learningrate is the learning rate, iteration the number of training passes, and seed the random seed used to initialize the weights.

2. Step-by-step design:

(1) Collect the parameters needed to build the network: the input dimension, the number of layers, and the number of neurons per layer.

public DeepNeuralNetWork(int nin){
    this.nin = nin;
}
public DeepNeuralNetWork addLayer(Layer layer){
    layers.add(layer);
    return this;
}
public DeepNeuralNetWork addLastLayer(LastLayer lastLayer){
    this.lastLayer = lastLayer;
    Init();
    return this;
}

    The constructor takes the input dimension, and each added layer supplies the per-layer information. Layer is an interface exposing, among other things, each layer's activation function and neuron count. LastLayer is the type of the final layer; unlike the others, it also has a loss function. Once the last layer is added, the parameters can be initialized.

(2) Initialize the network from these parameters.

public void Init(){
    Network_W = new INDArray[layers.size()+1];
    Network_B = new INDArray[layers.size()+1];
    for(int i=0;i<layers.size()+1;i++){
        if(i==0){ // the first layer: its fan-in is the input dimension
            if(layers.size()>0){
                Network_W[i] = layers.get(i).getWinit().Init(seed,layers.get(i).getNeuralNumber(), nin);
                Network_B[i] = Nd4j.zeros(layers.get(i).getNeuralNumber(), 1);
            }else{
                Network_W[i] = lastLayer.getWinit().Init(seed,lastLayer.getNeuralNumber(), nin);
                Network_B[i] = Nd4j.zeros(lastLayer.getNeuralNumber(), 1);
            }
        }else if(i==layers.size()){ // the last layer
            Network_W[i] = lastLayer.getWinit().Init(seed,lastLayer.getNeuralNumber(), layers.get(i-1).getNeuralNumber());
            Network_B[i] = Nd4j.zeros(lastLayer.getNeuralNumber(), 1);
        }else{
            Network_W[i] = layers.get(i).getWinit().Init(seed,layers.get(i).getNeuralNumber(), layers.get(i-1).getNeuralNumber());
            Network_B[i] = Nd4j.zeros(layers.get(i).getNeuralNumber(), 1);
        }
    }
}
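For the network assembled in section 3 below (784 inputs, tanh hidden layers of 1000, 500 and 100 units, and a 10-unit softmax output), Init() therefore allocates:

Network_W[0]: 1000 x 784     Network_B[0]: 1000 x 1
Network_W[1]:  500 x 1000    Network_B[1]:  500 x 1
Network_W[2]:  100 x 500     Network_B[2]:  100 x 1
Network_W[3]:   10 x 100     Network_B[3]:   10 x 1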

(3) Forward propagation

private INDArray linear_forward(INDArray A, INDArray W, INDArray b){
    return W.mmul(A).addColumnVector(b);
}
private INDArray linear_activate_forward(INDArray A_p, INDArray W, INDArray b, ActivateMethod activate){
    if(activate!=null){
        return activate.activate_forward(linear_forward(A_p,W,b));
    }
    return linear_forward(A_p,W,b);
}
private INDArray[] forward(INDArray X){
    INDArray[] res = new INDArray[layers.size()+1];
    INDArray P_A = X;
    for(int i=0;i<layers.size()+1;i++){
        if(i==layers.size()){ // the last layer uses lastLayer's own activation (softmax here)
            res[i] = linear_activate_forward(P_A, Network_W[i], Network_B[i], lastLayer.getActivateMethod());
        }else{
            res[i] = linear_activate_forward(P_A, Network_W[i], Network_B[i], layers.get(i).getActivateMethod());
            P_A = res[i]; // each layer's activation feeds the next layer
        }
    }
    return res;
}
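Samples travel through the network as columns: each layer computes Z = W·A_prev + b, and addColumnVector broadcasts the bias across the whole batch. A tiny standalone check of this convention (demo values, not part of the project):

INDArray W = Nd4j.create(new double[][]{{1, 0}, {0, 1}});       // (2,2) identity weights
INDArray A = Nd4j.create(new double[][]{{1, 2, 3}, {4, 5, 6}}); // (2,3): three samples as columns
INDArray b = Nd4j.create(new double[][]{{10}, {20}});           // (2,1) bias
System.out.println(W.mmul(A).addColumnVector(b));
// [[11.0, 12.0, 13.0],
//  [24.0, 25.0, 26.0]]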

(4) Backpropagation

private INDArray LossBackward(INDArray A,INDArray Y){
    return lastLayer.LastBackward(A,Y);
}

private double LossForward(INDArray A,INDArray Y){
    return lastLayer.getLossMethod().LossForward(A,Y);
}


private List<INDArray[]> backward(INDArray[] A_array,INDArray x,INDArray Y){
    INDArray DZ = LossBackward(A_array[A_array.length-1],Y);
    INDArray[] DW = new INDArray[A_array.length];
    INDArray[] DB = new INDArray[A_array.length];
    List<INDArray[]> res = new ArrayList<>();
    for(int i=A_array.length-1;i>=0;i--){
        if(i==0){ // last iteration: the first layer, whose input is x itself
            INDArray dW = DZ.mmul(x.transpose());
            INDArray dB = DZ.mmul(Nd4j.ones(x.shape()[1],1));
            DW[i] = dW.div(x.shape()[1]);
            DB[i] = dB.div(x.shape()[1]);
        }else{
            INDArray dW = DZ.mmul(A_array[i-1].transpose());
            INDArray dB = DZ.mmul(Nd4j.ones(x.shape()[1],1));
            DW[i] = dW.div(x.shape()[1]);
            DB[i] = dB.div(x.shape()[1]);
            DZ = activate_backward(Network_W[i].transpose().mmul(DZ),A_array[i-1],layers.get(i-1).getActivateMethod());
        }
    }
    res.add(DW);
    res.add(DB);
    return res;
}
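The loop implements the usual fully-connected backpropagation equations, averaged over the m = x.shape()[1] samples of the batch:

dW[l]   = (1/m) · dZ[l] · A[l-1]^T
dB[l]   = (1/m) · dZ[l] · 1                (the mmul with a ones vector sums over the batch)
dZ[l-1] = g'(A[l-1]) ⊙ (W[l]^T · dZ[l])    (⊙ = elementwise product, applied by activate_backward)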

(5) Gradient descent

private void update_parameters(List<INDArray[]> DW_DB){
    INDArray[] DW = DW_DB.get(0);
    INDArray[] DB = DW_DB.get(1);
    for(int i=0;i<Network_W.length;i++){
        // plain gradient descent: W := W - lr·dW, B := B - lr·dB
        Network_W[i] = Network_W[i].sub(DW[i].mul(learningrate));
        Network_B[i] = Network_B[i].sub(DB[i].mul(learningrate));
    }
}

(6) Computing the loss

private double LossForward(INDArray A,INDArray Y){
    return lastLayer.getLossMethod().LossForward(A,Y);
}
double loss = LossForward(A[A.length - 1],batch_list.get(j).getY());

(7) Putting it together and training

@Override
public void train(TrainData data) {
    for(int i=0;i<iteration;i++){
        List<BatchData> batch_list = data.getBatchList();
        for(int j=0;j<batch_list.size();j++){
            INDArray[] A = forward(batch_list.get(j).getX()); // forward propagation

            List<INDArray[]> DW_DB = backward(A, batch_list.get(j).getX(), batch_list.get(j).getY()); // backpropagation

            update_parameters(DW_DB); // gradient descent

            double loss = LossForward(A[A.length - 1],batch_list.get(j).getY());

            // print progress
            System.out.println("i=" + (i*batch_list.size()+j));
            System.out.println("loss=" + loss);
        }
    }
}

(8) Predicting on input data

@Override
public INDArray predict(INDArray x) {
    INDArray[] A = forward(x);
    return A[A.length-1];
}

(9) Other helper classes

// Math utility class
public class MyMathUtil {
    public static double Epow(double x){
        return Math.pow(Math.E, x);//e^x
    }

    public static INDArray Epow(INDArray value){
        return FUN_IND(value,v->Epow(v));
    }

    public static double Normalization(double value,double Max){
        return value/Max;
    }

    public static double MaxValue(INDArray value){
        if(value.shape()[0]>1){
            double[][] s = value.toDoubleMatrix();
            double my_Max=s[0][0];
            for(double[] si:s){
                for(double sj:si){
                    my_Max = Math.max(my_Max,sj);
                }
            }
            return my_Max;
        }else{
            double[] s = value.toDoubleVector();
            double my_Max=s[0];
            for(double si:s){
                my_Max = Math.max(my_Max,si);
            }
            return my_Max;
        }
    }

    public static INDArray Normalization(INDArray value){
        double my_Max = MaxValue(value);
        return FUN_IND(value,s->s/my_Max);
    }

    public static INDArray indArraysubMax(INDArray value){
        // subtract each column's max before softmax, for numerical stability
        if(value!=null){
            if(value.shape()[0]>1&&value.shape()[1]>1){
                double[][] s = value.transpose().toDoubleMatrix();
                for(int i=0;i<s.length;i++){
                    double max = s[i][0];
                    for(double sj:s[i]){
                        max = Math.max(max,sj);
                    }
                    for(int j=0;j<s[i].length;j++){
                        s[i][j] -= max;
                    }
                }
                return Nd4j.create(s).transpose();
            }
        }
        return value;
    }

    public static INDArray ONEHOT(INDArray value){
        // turn an (n,1) column of labels into an (n, Max+1) one-hot matrix
        double[] s = value.toDoubleVector();
        int Max = 0;
        for(double si:s){
            Max = Max>(int)si?Max:(int)si;
        }
        double[][] one_hot_res = new double[s.length][Max+1];
        for(int i=0;i<s.length;i++){
            one_hot_res[i][(int)s[i]] = 1;
        }
        return Nd4j.create(one_hot_res);
    }

    public static INDArray FUN_IND(INDArray value, Function<Double,Double> doubleFunction){
        // apply a scalar function elementwise to a matrix or vector
        if(value==null){
            return null;
        }
        if(value.shape()[0]>1&&value.shape()[1]>1){
            double[][] s = value.toDoubleMatrix();
            for(int i=0;i<s.length;i++){
                for(int j=0;j<s[i].length;j++){
                    s[i][j] = doubleFunction.apply(s[i][j]);
                }
            }
            return Nd4j.create(s);
        }
        double[] s = value.toDoubleVector();
        for(int i=0;i<s.length;i++){
            s[i] = doubleFunction.apply(s[i]);
        }
        return Nd4j.create(s);
    }

    public static double MysigMoid(double value){
        return 1.0/(1.0+Epow(-value)); // 1/(1+e^-x)
    }

    public static INDArray MysigMoid(INDArray value){
        return FUN_IND(value,v->MysigMoid(v));
    }

    public static double Mytanh(double value) {
        double ex = Math.pow(Math.E, value);// e^x
        double ey = Math.pow(Math.E, -value);//e^(-x)
        double sinhx = ex-ey;
        double coshx = ex+ey;
        return sinhx/coshx;
    }
    public static INDArray Mytanh(INDArray value) {
        return FUN_IND(value,v->Mytanh(v));
    }

    public static double relu(double value) {
        return Math.max(0,value);
    }

    public static INDArray relu(INDArray value) {
        return FUN_IND(value,v->relu(v));
    }

    public static double relu_back(double value) {
        if(value>0){
            return value;
        }else{
            return 0;
        }
    }

    public static INDArray relu_back(INDArray value) {
        return FUN_IND(value,v->relu_back(v));
    }

    public static double Log(double value) {
        return Math.log(value);
    }

    public static INDArray Log(INDArray value) {
        return FUN_IND(value,v->Log(v));
    }

    public static INDArray sotfmax(INDArray A){
        if(A!=null){
            A = MyMathUtil.Epow(A);                              // e^z; A: e.g. (10,128)
            INDArray sum_A = Nd4j.ones(1,A.shape()[0]).mmul(A);  // per-sample sums: (1,128)
            if(A.shape()[0]>1&&A.shape()[1]>1){
                double[][] A_s = A.transpose().toDoubleMatrix(); // (128,10)
                double[] SUM_A_s = sum_A.toDoubleVector();
                for(int i=0;i<A_s.length;i++){
                    for(int j=0;j<A_s[i].length;j++){
                        A_s[i][j] = A_s[i][j]/SUM_A_s[i];        // each column now sums to 1
                    }
                }
                return Nd4j.create(A_s).transpose();
            }
        }
        return A;
    }
}
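The SoftMax class below also calls MyMathUtil.sotfmax_back, whose body was lost from this write-up. A sketch consistent with its usage, where DA is the incoming loss gradient (−Y/A for cross-entropy) and A the softmax output, is the generic softmax backward pass dZ_i = A_i·(dA_i − Σ_j A_j·dA_j), computed per sample; with −Y/A and a one-hot Y it reduces to the familiar A − Y:

// assumed reconstruction, not the author's original body
public static INDArray sotfmax_back(INDArray DA, INDArray A){
    INDArray AdA = A.mul(DA);                               // A ⊙ dA, shape (classes, batch)
    INDArray colSum = Nd4j.ones(1, A.shape()[0]).mmul(AdA); // (1, batch) per-sample sums
    return AdA.sub(A.mulRowVector(colSum));                 // dZ
}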
// Activation method classes
public interface ActivateMethod {
    INDArray activate_forward(INDArray A);
    INDArray activate_backward(INDArray DA, INDArray A);
}
public class Relu implements ActivateMethod {

    @Override
    public INDArray activate_forward(INDArray A) {
        return MyMathUtil.relu(A);
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        // mask by the forward activation: gradient flows only where relu was active
        return DA.mul(MyMathUtil.FUN_IND(A, v -> v > 0 ? 1.0 : 0.0));
    }
}
public class SoftMax implements ActivateMethod {
    @Override
    public INDArray activate_forward(INDArray A) {
        return MyMathUtil.sotfmax(MyMathUtil.indArraysubMax(A));
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        return MyMathUtil.sotfmax_back(DA,A);
    }
}
public class Tanh implements ActivateMethod {
    @Override
    public INDArray activate_forward(INDArray A) {
        return MyMathUtil.Mytanh(A);
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        return DA.mul(Nd4j.ones(A.shape()).sub(A.mul(A)));
    }
}
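Note how Tanh.activate_backward exploits the stored forward value: since A = tanh(Z), the derivative tanh'(Z) = 1 − tanh²(Z) = 1 − A², so no Z values need to be cached.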

// Layer abstraction and implementations

public interface Layer {
    int getNeuralNumber();
    ActivateMethod getActivateMethod();
    Layer setActivateMethod(ActivateMethod activate);
    Layer setWInit(Winit wInit);
    Winit getWinit();
}
public interface LastLayer extends Layer {
    LossMethod getLossMethod();
    LastLayer setLossMethod(LossMethod lossMethod);
    default INDArray LastBackward(INDArray A,INDArray Y){
        return getActivateMethod().activate_backward(getLossMethod().LossBackward(A,Y),A);
    }
}
public class MyLayer implements Layer {

    private int number;
    private ActivateMethod activateMethod;
    private Winit winit;

    public MyLayer(int number,ActivateMethod method){
        this.number = number;
        this.activateMethod = method;
        this.winit = new XAVIER();
    }
    public MyLayer(int number,ActivateMethod method,Winit winit){
        this.number = number;
        this.activateMethod = method;
        this.winit = winit;
    }

    @Override
    public int getNeuralNumber() {
        return number;
    }


    @Override
    public ActivateMethod getActivateMethod() {
        return activateMethod;
    }

    @Override
    public Layer setActivateMethod(ActivateMethod activate) {
        this.activateMethod = activate;
        return this;
    }

    @Override
    public Layer setWInit(Winit winit) {
        this.winit = winit;
        return this;
    }

    @Override
    public Winit getWinit() {
        return winit;
    }
}
public class MyLastLayer extends MyLayer implements LastLayer {
    public MyLastLayer(int number,ActivateMethod method,LossMethod lossMethod){
        super(number,method);
        this.lossMethod = lossMethod;
    }
    private LossMethod lossMethod;
    @Override
    public LossMethod getLossMethod() {
        return lossMethod;
    }

    @Override
    public LastLayer setLossMethod(LossMethod lossMethod) {
        this.lossMethod = lossMethod;
        return this;
    }

}
// Loss function abstraction and implementations
public interface LossMethod {
    INDArray LossBackward(INDArray A,INDArray Y);
    double LossForward(INDArray A,INDArray Y);
}


public class CrossEntropy implements LossMethod {
    @Override
    public INDArray LossBackward(INDArray A, INDArray Y) {
        return Nd4j.zeros(Y.div(A).shape()).sub(Y.div(A));
    }

    @Override
    public double LossForward(INDArray A, INDArray Y) {
        INDArray los = Y.mul(MyMathUtil.Log(A));
        return  (0-los.sumNumber().doubleValue())/Y.shape()[1];
    }
}
public class MSE implements  LossMethod {
    @Override
    public INDArray LossBackward(INDArray A, INDArray Y) {
        return A.sub(Y);
    }

    @Override
    public double LossForward(INDArray A, INDArray Y) {
        return (Y.sub(A).mmul(Y.sub(A).transpose()).sumNumber().doubleValue())/Y.shape()[1];
    }
}
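A quick numeric check of CrossEntropy (a minimal sketch with demo values; a two-sample batch so the matrices follow the (classes, batch) convention). LossForward computes −(1/m)·Σ y·log(a), which here is −ln(0.7) per sample:

INDArray Y = Nd4j.create(new double[][]{{0,1},{0,0},{0,0},{1,0}});                 // one-hot labels: classes 3 and 0
INDArray A = Nd4j.create(new double[][]{{0.1,0.7},{0.1,0.1},{0.1,0.1},{0.7,0.1}}); // predicted probabilities
System.out.println(new CrossEntropy().LossForward(A, Y));                          // -(ln(0.7)+ln(0.7))/2 ≈ 0.3567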

// Weight initialization

public interface Winit {
    INDArray Init(int seed,int out,int in);
}
public class XAVIER implements Winit{

    @Override
    public INDArray Init(int seed,int out, int in) {
        return Nd4j.randn(out,in,seed).muli(FastMath.sqrt(2.0 / (in+out)));
    }
}
public class RELUWInit implements Winit {
    @Override
    public INDArray Init(int seed,int out, int in) {
        return Nd4j.randn(out,in,seed).muli(FastMath.sqrt(2.0 / in));
    }
}
public class RandWInit implements Winit {
    @Override
    public INDArray Init(int seed, int out, int in) {
        return Nd4j.rand(out,in,seed);
    }
}
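XAVIER scales a standard normal by sqrt(2/(in+out)), a sensible default for the tanh layers used here; RELUWInit is He initialization, sqrt(2/in), better matched to relu layers; RandWInit draws plain uniform values and is mostly useful as a baseline.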

 // Training data classes

public interface INData {
    INDArray getX();
    INDArray getY();
    int getSize();
}
public interface TrainData extends INData{
     List<BatchData> getBatchList();
}
public class BatchData implements INData{
    private INDArray x;
    private INDArray y;

    public BatchData(INDArray x,INDArray y){
        this.x = x;
        this.y = y;
    }
    @Override
    public INDArray getX() {
        return x;
    }

    @Override
    public INDArray getY() {
        return y;
    }

    @Override
    public int getSize() {
        return getX().columns();
    }

    public void setX(INDArray x) {
        this.x = x;
    }

    public void setY(INDArray y) {
        this.y = y;
    }
}
public class MyTrainData implements TrainData {
    private INDArray x;
    private INDArray y;
    private int batch_size;

    public MyTrainData(INDArray x,INDArray y,int batch_size){
        this.x = x.transpose();
        this.y = y.transpose();
        this.batch_size = batch_size;
    }
    public MyTrainData(INDArray x,INDArray y){
        this.x = x.transpose();
        this.y = y.transpose();
        this.batch_size = -1;
    }

    @Override
    public INDArray getX() {
        return x;
    }

    @Override
    public INDArray getY() {
        return y;
    }

    @Override
    public List<BatchData> getBatchList() {
        List<BatchData> res = new ArrayList<>();
        shufflecard();
        if(batch_size!=-1){
            // slice the (features, samples) matrices into column blocks of batch_size
            int lastColumnOrder = 0;
            for(int i=batch_size;i<=x.columns();i+=batch_size){
                res.add(new BatchData(
                        x.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, i)),
                        y.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, i))));
                lastColumnOrder = i;
            }
            if(lastColumnOrder < x.columns()){ // remainder batch
                res.add(new BatchData(
                        x.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, x.columns())),
                        y.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, x.columns()))));
            }
        }else{
            res.add(new BatchData(x, y)); // no batching: one full-dataset batch
        }
        return res;
    }

    // reconstructed sketch: shuffle the sample (column) order of x and y with the same permutation
    private void shufflecard(){
        int cols = x.columns();
        int[] perm = new int[cols];
        for(int i=0;i<cols;i++) perm[i]=i;
        Random rnd = new Random();
        for(int i=cols-1;i>0;i--){
            int j = rnd.nextInt(i+1);
            int t = perm[i]; perm[i] = perm[j]; perm[j] = t;
        }
        INDArray nx = Nd4j.zeros(x.rows(), cols);
        INDArray ny = Nd4j.zeros(y.rows(), cols);
        for(int i=0;i<cols;i++){
            nx.putColumn(i, x.getColumn(perm[i]));
            ny.putColumn(i, y.getColumn(perm[i]));
        }
        x = nx;
        y = ny;
    }
}

3. Build the network and feed it the training data:

public static final ClassPathResource TRAIN_IMAGES_FILE = new ClassPathResource("data/train-images.idx3-ubyte");
public static final ClassPathResource TRAIN_LABELS_FILE = new ClassPathResource("data/train-labels.idx1-ubyte");
public static final ClassPathResource TEST_IMAGES_FILE = new ClassPathResource("data/t10k-images.idx3-ubyte");
public static final ClassPathResource TEST_LABELS_FILE = new ClassPathResource("data/t10k-labels.idx1-ubyte");
private model pointmodel = new DeepNeuralNetWork(28*28)
        .addLayer(new MyLayer(1000,new Tanh()))
        .addLayer(new MyLayer(500,new Tanh()))
        .addLayer(new MyLayer(100,new Tanh()))
        .addLastLayer(new SotfMaxCrossEntropyLastLayer(10)).setIteration(10).setLearningrate(0.06);

double[][] images = MnistReadUtil.getImages(TRAIN_IMAGES_FILE.getInputStream());
double[] labels = MnistReadUtil.getLabels(TRAIN_LABELS_FILE.getInputStream());
INDArray X = Nd4j.create(images);             // 60000,784
INDArray Y = Nd4j.create(labels).transpose(); // 60000,1
INDArray X_I = MyMathUtil.Normalization(X);   // scale pixel values into [0,1]
INDArray Y_I = MyMathUtil.ONEHOT(Y);          // 60000,10
TrainData data = new MyTrainData(X_I,Y_I,128); // mini-batches of 128
pointmodel.train(data);
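Note that MyTrainData transposes both matrices in its constructor, so inside the network X is (784, 60000) and Y is (10, 60000): each column is one sample, exactly the orientation that W.mmul(A) in the forward pass expects.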


4. Test the model

double[][] t_images = MnistReadUtil.getImages(TEST_IMAGES_FILE.getInputStream());
double[] t_labels = MnistReadUtil.getLabels(TEST_LABELS_FILE.getInputStream());
INDArray X_t = MyMathUtil.Normalization(Nd4j.create(t_images));
INDArray Y_t = MyMathUtil.ONEHOT(Nd4j.create(t_labels).transpose());
TrainData data_t = new MyTrainData(X_t,Y_t);
INDArray X_P = pointmodel.predict(data_t.getX());
System.out.println("正确率:"+scord(X_P,data_t.getY())+"%");

// Pick the class with the highest probability and compare it with the label.

private float scord(INDArray value,INDArray Y) {
    int res = 0;
    int sum = 0;
    double[][] s = value.transpose().toDoubleMatrix();
    double[][] Ys = Y.transpose().toDoubleMatrix();
    for(int i=0;i<s.length;i++){
        // argmax over the class probabilities of sample i
        int order = 0;
        double Max = s[i][0];
        for(int j=0;j<s[i].length;j++){
            order = Max>s[i][j]?order:j;
            Max = Max>s[i][j]?Max:s[i][j];
        }
        // a hit when the one-hot label is set at the predicted class
        // (no order>0 guard here, otherwise correct predictions of the digit 0 would never be counted)
        if((int)Ys[i][order]==1){
            res++;
        }
        sum++;
    }
    if(sum>0){
        return ((float)res/sum)*100;
    }else{
        return 0;
    }
}

5. Upload an image and predict

public String predict(@RequestParam(value = "file") MultipartFile file, ModelMap map){
    if (file.isEmpty()) {
        System.out.println("文件为空空");
    }
    try{
        File my_file = File.createTempFile("tmp", null);
        file.transferTo(my_file);
        double[] m = MnistReadUtil.getSizeBlackWhiteImg(my_file,28,28);
        INDArray X_t = MyMathUtil.Normalization(Nd4j.create(m));
        INDArray X_P = pointmodel.predict(X_t.transpose());
        int number = getnumber(X_P);
        map.addAttribute ("number",number);
        return "freemarker/mnist/predict";
    }catch (Exception e){
        e.printStackTrace();
    }
    return "freemarker/fail";
}
private int getnumber(INDArray X){
    // argmax over the 10 output probabilities
    double[] s = X.toDoubleVector();
    int res = 0;
    double Max = s[0];
    for(int i=0;i<s.length;i++){
        if(s[i]>Max){
            Max = s[i];
            res = i;
        }
    }
    return res;
}
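The same prediction path can be exercised without the web layer (a sketch; my_digit.png is a placeholder for any image file, and pointmodel is the trained model from section 3):

double[] m = MnistReadUtil.getSizeBlackWhiteImg(new File("my_digit.png"), 28, 28);
INDArray X_t = MyMathUtil.Normalization(Nd4j.create(m)); // (1,784) row vector, scaled to [0,1]
INDArray X_P = pointmodel.predict(X_t.transpose());      // (784,1) in, (10,1) class probabilities out
System.out.println("predicted digit: " + getnumber(X_P));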


Full source code: [email protected]:woshiyigebing/my_dl4j.git
