The MNIST dataset is available at http://yann.lecun.com/exdb/mnist/ and consists of four parts: train-images-idx3-ubyte (training-set images), train-labels-idx1-ubyte (training-set labels), t10k-images-idx3-ubyte (test-set images), and t10k-labels-idx1-ubyte (test-set labels).
MNIST comes from the US National Institute of Standards and Technology (NIST). The training set consists of digits handwritten by 250 different people: 50% were high-school students and 50% were employees of the Census Bureau. The test set contains handwritten digits in the same proportions.
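All four files use NIST's IDX binary layout, which the reader below parses by hand: a 4-byte big-endian magic number (0x00000803 for image files, 0x00000801 for label files), then a 4-byte sample count, then — for image files only — two more 4-byte fields holding the row and column pixel counts, followed by the raw payload of one unsigned byte per pixel or per label.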
package com.yan.dl4j.Utils;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;

public class MnistReadUtil {

    public static final String TRAIN_IMAGES_FILE = "src\\main\\resources\\data\\train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "src\\main\\resources\\data\\train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "src\\main\\resources\\data\\t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "src\\main\\resources\\data\\t10k-labels.idx1-ubyte";

    /**
     * change bytes into a hex string.
     *
     * @param bytes bytes
     * @return the returned hex string
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < bytes.length; i++) {
            String hex = Integer.toHexString(bytes[i] & 0xFF);
            if (hex.length() < 2) {
                sb.append(0);
            }
            sb.append(hex);
        }
        return sb.toString();
    }

    /**
     * get images of 'train' or 'test'
     *
     * @param fileName the file of 'train' or 'test' about image
     * @return one row per picture
     */
    public static double[][] getImages(String fileName) {
        try {
            return getImages(new FileInputStream(fileName));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    public static double[][] getImages(InputStream inputStream) {
        double[][] x;
        try (BufferedInputStream bin = new BufferedInputStream(inputStream)) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000803".equals(bytesToHex(bytes))) { // read the magic number
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16); // total number of samples
                bin.read(bytes, 0, 4);
                int xPixel = Integer.parseInt(bytesToHex(bytes), 16); // pixels per row
                bin.read(bytes, 0, 4);
                int yPixel = Integer.parseInt(bytesToHex(bytes), 16); // pixels per column
                x = new double[number][xPixel * yPixel];
                for (int i = 0; i < number; i++) {
                    double[] element = new double[xPixel * yPixel];
                    for (int j = 0; j < xPixel * yPixel; j++) {
                        element[j] = bin.read(); // read pixel values one at a time
                        // normalization
                        // element[j] = bin.read() / 255.0;
                    }
                    x[i] = element;
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return x;
    }

    /**
     * get labels of 'train' or 'test'
     *
     * @param fileName the file of 'train' or 'test' about label
     * @return labels
     */
    public static double[] getLabels(String fileName) {
        try {
            return getLabels(new FileInputStream(fileName));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    public static double[] getLabels(InputStream inputStream) {
        double[] y;
        try (BufferedInputStream bin = new BufferedInputStream(inputStream)) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000801".equals(bytesToHex(bytes))) {
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16);
                y = new double[number];
                for (int i = 0; i < number; i++) {
                    y[i] = bin.read();
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return y;
    }

    public static void drawGrayPicture(double[] pixelValues, String fileName) throws IOException {
        // convert double to int; the rest of this method is cut off in the source --
        // a typical body packs the grayscale values into a BufferedImage and writes it:
        int[] res = new int[pixelValues.length];
        for (int i = 0; i < pixelValues.length; i++) {
            res[i] = (int) pixelValues[i];
        }
        BufferedImage bufferedImage = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < 28; x++) {
            for (int y = 0; y < 28; y++) {
                int gray = res[y * 28 + x];
                bufferedImage.setRGB(x, y, new Color(gray, gray, gray).getRGB());
            }
        }
        ImageIO.write(bufferedImage, "JPEG", new File(fileName));
    }

    public static double[] getSizeBlackWhiteImg(File file, int width, int height) throws IOException {
        // the start of this method is cut off in the source; it reads the uploaded
        // image and rescales it to width x height before the recovered loop below:
        BufferedImage source = ImageIO.read(file);
        BufferedImage img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        img.getGraphics().drawImage(source, 0, 0, width, height, null);
        double[] result = new double[width * height];
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < height; j++) {
                int r = img.getRGB(i, j);
                int[] rgb = new int[3];
                rgb[0] = (r & 0xff0000) >> 16;
                rgb[1] = (r & 0xff00) >> 8;
                rgb[2] = (r & 0xff);
                // luminance-weighted grayscale, inverted so ink is bright like MNIST
                int color = 255 - (int) (rgb[0] * 0.3 + rgb[1] * 0.59 + rgb[2] * 0.11);
                result[i * height + j] = color;
            }
        }
        return result;
    }
}
In this utility class, getImages converts an MNIST image file into a double[][]. Taking train-images.idx3-ubyte as an example, the result is a double[60000][784]: each double[] is one image, there are 60,000 images in total, and each image contributes 28*28 = 784 feature values. getLabels converts an MNIST label file into a double[]. drawGrayPicture turns a double[] back into an image a human can inspect. getSizeBlackWhiteImg grayscales an arbitrary input image, rescales it to 28*28, and flattens it into a double[].
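Before wiring anything else up, it is worth sanity-checking the reader. A minimal sketch (run inside a main that declares throws IOException; the output file name digit0.jpg is arbitrary):

double[][] images = MnistReadUtil.getImages(MnistReadUtil.TRAIN_IMAGES_FILE);
double[] labels = MnistReadUtil.getLabels(MnistReadUtil.TRAIN_LABELS_FILE);
System.out.println(images.length + " x " + images[0].length); // 60000 x 784
System.out.println("first label: " + (int) labels[0]);        // 5 for the standard training set
MnistReadUtil.drawGrayPicture(images[0], "digit0.jpg");       // visually confirm it looks like a 5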
1. Neural network fields
private List<Layer> layers = new ArrayList<>();
private LastLayer lastLayer;
private INDArray[] Network_W;
private INDArray[] Network_B;
private int nin;
private int seed = 123;
private double learningrate = 0.01;
private int iteration = 10;
nin is the input dimension. Layer is the abstraction for one network layer, and LastLayer is the final layer; it is a separate type because the final layer also carries the loss function. Network_W holds each layer's weight matrix W and Network_B each layer's bias vector B. learningrate is the learning rate, iteration the number of training iterations, and seed the random seed for weight initialization.
2. Approach, step by step:
(1) Collect the parameters needed to build the network: the input dimension, the number of layers, and the number of neurons in each layer.
public DeepNeuralNetWork(int nin) {
    this.nin = nin;
}

public DeepNeuralNetWork addLayer(Layer layer) {
    layers.add(layer);
    return this;
}

public DeepNeuralNetWork addLastLayer(LastLayer lastLayer) {
    this.lastLayer = lastLayer;
    Init();
    return this;
}
The constructor takes the input dimension, and each addLayer call supplies one layer's configuration. Layer is an interface exposing, among other things, the layer's activation function and its neuron count. LastLayer is the type of the final layer; unlike the others it also carries a loss function. Once the final layer is added, the parameters can be initialized.
(2) Initialize the network from those parameters.
public void Init() {
    Network_W = new INDArray[layers.size() + 1];
    Network_B = new INDArray[layers.size() + 1];
    for (int i = 0; i <= layers.size(); i++) {
        if (i == 0) {
            if (layers.size() > 0) {
                Network_W[i] = layers.get(i).getWinit().Init(seed, layers.get(i).getNeuralNumber(), nin);
                Network_B[i] = Nd4j.zeros(layers.get(i).getNeuralNumber(), 1);
            } else {
                Network_W[i] = lastLayer.getWinit().Init(seed, lastLayer.getNeuralNumber(), nin);
                Network_B[i] = Nd4j.zeros(lastLayer.getNeuralNumber(), 1);
            }
        } else if (i == layers.size()) { // the last layer
            Network_W[i] = lastLayer.getWinit().Init(seed, lastLayer.getNeuralNumber(), layers.get(i - 1).getNeuralNumber());
            Network_B[i] = Nd4j.zeros(lastLayer.getNeuralNumber(), 1);
        } else {
            Network_W[i] = layers.get(i).getWinit().Init(seed, layers.get(i).getNeuralNumber(), layers.get(i - 1).getNeuralNumber());
            Network_B[i] = Nd4j.zeros(layers.get(i).getNeuralNumber(), 1);
        }
    }
}
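To make the shapes concrete: for the 784-1000-500-100-10 network built in section 3 below, Init() allocates one (out, in) weight matrix and one (out, 1) bias column per layer:

Network_W[0]: 1000 x 784     Network_B[0]: 1000 x 1
Network_W[1]:  500 x 1000    Network_B[1]:  500 x 1
Network_W[2]:  100 x 500     Network_B[2]:  100 x 1
Network_W[3]:   10 x 100     Network_B[3]:   10 x 1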
(3) Forward propagation
private INDArray linear_forward(INDArray A, INDArray W, INDArray b) {
    return W.mmul(A).addColumnVector(b);
}

private INDArray linear_activate_forward(INDArray A_p, INDArray W, INDArray b, ActivateMethod activate) {
    if (activate != null) {
        return activate.activate_forward(linear_forward(A_p, W, b));
    }
    return linear_forward(A_p, W, b);
}

private INDArray[] forward(INDArray X) {
    INDArray[] res = new INDArray[layers.size() + 1];
    INDArray P_A = X;
    // (the loop body is cut off in the source; reconstructed here) walk the hidden
    // layers, caching every activation so backward() can reuse them:
    for (int i = 0; i < layers.size(); i++) {
        P_A = linear_activate_forward(P_A, Network_W[i], Network_B[i], layers.get(i).getActivateMethod());
        res[i] = P_A;
    }
    res[layers.size()] = linear_activate_forward(P_A, Network_W[layers.size()], Network_B[layers.size()], lastLayer.getActivateMethod());
    return res;
}
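In matrix form, with each sample stored as a column, layer l computes Z[l] = W[l]·A[l-1] + b[l] followed by A[l] = g(Z[l]), where g is the layer's activation; addColumnVector is what broadcasts the bias column b across every sample in the batch.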
(4) Backpropagation
private INDArray LossBackward(INDArray A, INDArray Y) {
    return lastLayer.LastBackward(A, Y);
}

private double LossForward(INDArray A, INDArray Y) {
    return lastLayer.getLossMethod().LossForward(A, Y);
}

// (not shown in the post; assumed to mirror linear_activate_forward's null handling)
private INDArray activate_backward(INDArray DA, INDArray A, ActivateMethod activate) {
    if (activate != null) {
        return activate.activate_backward(DA, A);
    }
    return DA;
}

private List<INDArray[]> backward(INDArray[] A_array, INDArray x, INDArray Y) {
    INDArray DZ = LossBackward(A_array[A_array.length - 1], Y);
    INDArray[] DW = new INDArray[A_array.length];
    INDArray[] DB = new INDArray[A_array.length];
    List<INDArray[]> res = new ArrayList<>();
    for (int i = A_array.length - 1; i >= 0; i--) {
        if (i == 0) { // final step: the first layer's input is x itself
            INDArray dW = DZ.mmul(x.transpose());
            INDArray dB = DZ.mmul(Nd4j.ones(x.shape()[1], 1));
            DW[i] = dW.div(x.shape()[1]);
            DB[i] = dB.div(x.shape()[1]);
        } else {
            INDArray dW = DZ.mmul(A_array[i - 1].transpose());
            INDArray dB = DZ.mmul(Nd4j.ones(x.shape()[1], 1));
            DW[i] = dW.div(x.shape()[1]);
            DB[i] = dB.div(x.shape()[1]);
            DZ = activate_backward(Network_W[i].transpose().mmul(DZ), A_array[i - 1], layers.get(i - 1).getActivateMethod());
        }
    }
    res.add(DW);
    res.add(DB);
    return res;
}
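With m = x.shape()[1] samples in the batch, the loop implements the standard fully-connected gradients: dW[l] = (1/m)·dZ[l]·A[l-1]^T and db[l] = (1/m)·dZ[l]·1 (multiplying by a column of ones sums dZ over the batch), while the error signal is pushed down one layer via dZ[l-1] = g'(A[l-1]) ⊙ (W[l]^T·dZ[l]).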
(5) Gradient descent
private void update_parameters(List<INDArray[]> DW_DB) {
    INDArray[] DW = DW_DB.get(0);
    INDArray[] DB = DW_DB.get(1);
    // (the loop body is cut off in the source; reconstructed here) the vanilla
    // update W := W - learningrate*dW, B := B - learningrate*dB for every layer:
    for (int i = 0; i < Network_W.length; i++) {
        Network_W[i] = Network_W[i].sub(DW[i].mul(learningrate));
        Network_B[i] = Network_B[i].sub(DB[i].mul(learningrate));
    }
}
(6) Computing the loss
private double LossForward(INDArray A, INDArray Y) {
    return lastLayer.getLossMethod().LossForward(A, Y);
}
double loss = LossForward(A[A.length - 1],batch_list.get(j).getY());
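With the CrossEntropy implementation shown in step (9), this evaluates L = -(1/m) Σ y·log(a), summed over classes and samples and averaged over the m batch columns. LossForward only reports the scalar for monitoring; the matching LossBackward supplies the gradient that seeds backpropagation.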
(7) Putting the pieces together and training
@Override
public void train(TrainData data) {
    for (int i = 0; i < iteration; i++) {
        List<BatchData> batch_list = data.getBatchList();
        for (int j = 0; j < batch_list.size(); j++) {
            INDArray[] A = forward(batch_list.get(j).getX()); // forward propagation
            List<INDArray[]> DW_DB = backward(A, batch_list.get(j).getX(), batch_list.get(j).getY()); // backpropagation
            update_parameters(DW_DB); // gradient descent
            double loss = LossForward(A[A.length - 1], batch_list.get(j).getY());
            // print progress
            System.out.println("i=" + (i * batch_list.size() + j));
            System.out.println("loss=" + loss);
        }
    }
}
(8) Predicting on input data
@Override
public INDArray predict(INDArray x) {
    INDArray[] A = forward(x);
    return A[A.length - 1];
}
(9) Other helper classes
// math utility class
public class MyMathUtil {

    public static double Epow(double x) {
        return Math.pow(Math.E, x); // e^x
    }

    public static INDArray Epow(INDArray value) {
        return FUN_IND(value, v -> Epow(v));
    }

    public static double Normalization(double value, double Max) {
        return value / Max;
    }

    public static double MaxValue(INDArray value) {
        if (value.shape()[0] > 1) {
            double[][] s = value.toDoubleMatrix();
            double my_Max = s[0][0];
            for (double[] si : s) {
                for (double sj : si) {
                    my_Max = Math.max(my_Max, sj);
                }
            }
            return my_Max;
        } else {
            double[] s = value.toDoubleVector();
            double my_Max = s[0];
            for (double si : s) {
                my_Max = Math.max(my_Max, si);
            }
            return my_Max;
        }
    }

    public static INDArray Normalization(INDArray value) {
        double my_Max = MaxValue(value);
        return FUN_IND(value, s -> s / my_Max);
    }

    // subtract each column's maximum from that column (softmax overflow guard);
    // the interior is cut off in the source -- reconstructed in the looping
    // style of the other methods:
    public static INDArray indArraysubMax(INDArray value) {
        if (value != null) {
            if (value.shape()[0] > 1 && value.shape()[1] > 1) {
                double[][] s = value.transpose().toDoubleMatrix();
                for (int i = 0; i < s.length; i++) {
                    double max = s[i][0];
                    for (double sj : s[i]) {
                        max = Math.max(max, sj);
                    }
                    for (int j = 0; j < s[i].length; j++) {
                        s[i][j] -= max;
                    }
                }
                return Nd4j.create(s).transpose();
            } else {
                return value.sub(MaxValue(value));
            }
        }
        return null;
    }

    // one-hot encode a label vector (signature and loop partially cut off in the source):
    public static INDArray ONEHOT(INDArray value) {
        double[] s = value.toDoubleVector();
        int Max = 0;
        for (double si : s) {
            Max = Max > si ? Max : (int) si;
        }
        double[][] one_hot_res = new double[s.length][Max + 1];
        for (int i = 0; i < s.length; i++) {
            one_hot_res[i][(int) s[i]] = 1;
        }
        return Nd4j.create(one_hot_res);
    }

    // apply a scalar function elementwise (interior cut off in the source; reconstructed):
    public static INDArray FUN_IND(INDArray value, Function<Double, Double> doubleFunction) {
        if (value != null) {
            if (value.shape()[0] > 1 && value.shape()[1] > 1) {
                double[][] s = value.toDoubleMatrix();
                for (int i = 0; i < s.length; i++) {
                    for (int j = 0; j < s[i].length; j++) {
                        s[i][j] = doubleFunction.apply(s[i][j]);
                    }
                }
                return Nd4j.create(s);
            } else {
                double[] s = value.toDoubleVector();
                for (int i = 0; i < s.length; i++) {
                    s[i] = doubleFunction.apply(s[i]);
                }
                return Nd4j.create(s);
            }
        }
        return null;
    }

    // scalar definition cut off in the source; the standard logistic function is assumed:
    public static double MysigMoid(double value) {
        return 1.0 / (1.0 + Epow(-value));
    }

    public static INDArray MysigMoid(INDArray value) {
        return FUN_IND(value, v -> MysigMoid(v));
    }

    public static double Mytanh(double value) {
        double ex = Math.pow(Math.E, value);  // e^x
        double ey = Math.pow(Math.E, -value); // e^(-x)
        double sinhx = ex - ey;
        double coshx = ex + ey;
        return sinhx / coshx;
    }

    public static INDArray Mytanh(INDArray value) {
        return FUN_IND(value, v -> Mytanh(v));
    }

    public static double relu(double value) {
        return Math.max(0, value);
    }

    public static INDArray relu(INDArray value) {
        return FUN_IND(value, v -> relu(v));
    }

    public static double relu_back(double value) {
        if (value > 0) {
            return value;
        } else {
            return 0;
        }
    }

    public static INDArray relu_back(INDArray value) {
        return FUN_IND(value, v -> relu_back(v));
    }

    public static double Log(double value) {
        return Math.log(value);
    }

    public static INDArray Log(INDArray value) {
        return FUN_IND(value, v -> Log(v));
    }

    public static INDArray sotfmax(INDArray A) {
        if (A != null) {
            A = MyMathUtil.Epow(A);                              // A: 10,128 (classes x batch)
            INDArray sum_A = Nd4j.ones(1, A.shape()[0]).mmul(A); // 1,128
            if (A.shape()[0] > 1 && A.shape()[1] > 1) {
                double[][] A_s = A.transpose().toDoubleMatrix(); // 128,10
                double[] SUM_A_s = sum_A.toDoubleVector();
                // (remainder cut off in the source; reconstructed) divide each
                // sample's row by its sum so every column sums to 1:
                for (int i = 0; i < A_s.length; i++) {
                    for (int j = 0; j < A_s[i].length; j++) {
                        A_s[i][j] /= SUM_A_s[i];
                    }
                }
                return Nd4j.create(A_s).transpose();
            } else {
                return A.div(A.sumNumber().doubleValue());
            }
        }
        return null;
    }

    // sotfmax_back (used by SoftMax.activate_backward below) is also cut off
    // in the source and is not reconstructed here.
}
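A note on the indArraysubMax/sotfmax pairing used by the SoftMax class below: softmax is shift-invariant, softmax(z) = softmax(z - max(z)), while e^z overflows a double for modestly large z, so subtracting each column's maximum first keeps Epow finite without changing the output.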
// activation function abstraction
public interface ActivateMethod {
    INDArray activate_forward(INDArray A);
    INDArray activate_backward(INDArray DA, INDArray A);
}
public class Relu implements ActivateMethod {
    @Override
    public INDArray activate_forward(INDArray A) {
        return MyMathUtil.relu(A);
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        // relu'(z) is 1 where the activation is positive and 0 elsewhere, so the
        // incoming gradient must be masked by A > 0 (applying relu_back to DA
        // itself would gate on the gradient's own sign, which is a bug):
        return DA.mul(MyMathUtil.FUN_IND(A, v -> v > 0 ? 1.0 : 0.0));
    }
}
public class SoftMax implements ActivateMethod {
    @Override
    public INDArray activate_forward(INDArray A) {
        // shift by the column max first, then exponentiate and normalize
        return MyMathUtil.sotfmax(MyMathUtil.indArraysubMax(A));
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        return MyMathUtil.sotfmax_back(DA, A);
    }
}
public class Tanh implements ActivateMethod {
    @Override
    public INDArray activate_forward(INDArray A) {
        return MyMathUtil.Mytanh(A);
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        return DA.mul(Nd4j.ones(A.shape()).sub(A.mul(A)));
    }
}
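Tanh's backward pass exploits the identity tanh'(z) = 1 - tanh(z)^2. Because the cached A is already tanh(z), the gradient is just DA ⊙ (1 - A²), which the expression above computes without re-evaluating tanh.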
// layer abstraction and implementations
public interface Layer {
    int getNeuralNumber();
    ActivateMethod getActivateMethod();
    Layer setActivateMethod(ActivateMethod activate);
    Layer setWInit(Winit wInit);
    Winit getWinit();
}
public interface LastLayer extends Layer {
    LossMethod getLossMethod();
    LastLayer setLossMethod(LossMethod lossMethod);

    default INDArray LastBackward(INDArray A, INDArray Y) {
        return getActivateMethod().activate_backward(getLossMethod().LossBackward(A, Y), A);
    }
}
public class MyLayer implements Layer {
    private int number;
    private ActivateMethod activateMethod;
    private Winit winit;

    public MyLayer(int number, ActivateMethod method) {
        this.number = number;
        this.activateMethod = method;
        this.winit = new XAVIER();
    }

    public MyLayer(int number, ActivateMethod method, Winit winit) {
        this.number = number;
        this.activateMethod = method;
        this.winit = winit;
    }

    @Override
    public int getNeuralNumber() { return number; }

    @Override
    public ActivateMethod getActivateMethod() { return activateMethod; }

    @Override
    public Layer setActivateMethod(ActivateMethod activate) {
        this.activateMethod = activate;
        return this;
    }

    @Override
    public Layer setWInit(Winit winit) {
        this.winit = winit;
        return this;
    }

    @Override
    public Winit getWinit() { return winit; }
}
public class MyLastLayer extends MyLayer implements LastLayer {
    private LossMethod lossMethod;

    public MyLastLayer(int number, ActivateMethod method, LossMethod lossMethod) {
        super(number, method);
        this.lossMethod = lossMethod;
    }

    @Override
    public LossMethod getLossMethod() { return lossMethod; }

    @Override
    public LastLayer setLossMethod(LossMethod lossMethod) {
        this.lossMethod = lossMethod;
        return this;
    }
}
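The training code in section 3 constructs a SotfMaxCrossEntropyLastLayer that the post never lists. A minimal sketch consistent with the interfaces above (the constructor wiring is an assumption): it pairs SoftMax with CrossEntropy and overrides LastBackward, since the combined softmax-plus-cross-entropy gradient collapses to A - Y.

public class SotfMaxCrossEntropyLastLayer extends MyLastLayer {

    public SotfMaxCrossEntropyLastLayer(int number) {
        super(number, new SoftMax(), new CrossEntropy());
    }

    @Override
    public INDArray LastBackward(INDArray A, INDArray Y) {
        // d(crossEntropy ∘ softmax)/dZ simplifies to A - Y,
        // so the generic activate_backward chain is bypassed
        return A.sub(Y);
    }
}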
// loss function abstraction and implementations
public interface LossMethod {
    INDArray LossBackward(INDArray A, INDArray Y);
    double LossForward(INDArray A, INDArray Y);
}
public class CrossEntropy implements LossMethod {
    @Override
    public INDArray LossBackward(INDArray A, INDArray Y) {
        // dL/dA = -Y / A
        return Nd4j.zeros(Y.div(A).shape()).sub(Y.div(A));
    }

    @Override
    public double LossForward(INDArray A, INDArray Y) {
        INDArray los = Y.mul(MyMathUtil.Log(A));
        return (0 - los.sumNumber().doubleValue()) / Y.shape()[1];
    }
}
public class MSE implements LossMethod {
    @Override
    public INDArray LossBackward(INDArray A, INDArray Y) {
        return A.sub(Y);
    }

    @Override
    public double LossForward(INDArray A, INDArray Y) {
        // elementwise squared error, summed and averaged over the batch;
        // the mmul-with-transpose form only equals this when Y has a single row
        return Y.sub(A).mul(Y.sub(A)).sumNumber().doubleValue() / Y.shape()[1];
    }
}
// weight (W) initialization
public interface Winit {
    INDArray Init(int seed, int out, int in);
}

public class XAVIER implements Winit {
    @Override
    public INDArray Init(int seed, int out, int in) {
        return Nd4j.randn(out, in, seed).muli(FastMath.sqrt(2.0 / (in + out)));
    }
}

public class RELUWInit implements Winit {
    @Override
    public INDArray Init(int seed, int out, int in) {
        return Nd4j.randn(out, in, seed).muli(FastMath.sqrt(2.0 / in));
    }
}

public class RandWInit implements Winit {
    @Override
    public INDArray Init(int seed, int out, int in) {
        return Nd4j.rand(out, in, seed);
    }
}
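The three schemes differ only in scale: XAVIER draws from N(0, 2/(in+out)), a good match for tanh or sigmoid layers; RELUWInit uses the He variance 2/in recommended for ReLU; RandWInit draws uniform values in [0, 1), rarely a good default but useful as a baseline.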
// training data classes
public interface INData {
    INDArray getX();
    INDArray getY();
    int getSize();
}
public interface TrainData extends INData {
    List<BatchData> getBatchList();
}

public class BatchData implements INData {
    private INDArray x;
    private INDArray y;

    public BatchData(INDArray x, INDArray y) {
        this.x = x;
        this.y = y;
    }

    @Override
    public INDArray getX() { return x; }

    @Override
    public INDArray getY() { return y; }

    @Override
    public int getSize() { return getX().columns(); }

    public void setX(INDArray x) { this.x = x; }

    public void setY(INDArray y) { this.y = y; }
}
public class MyTrainData implements TrainData {
    private INDArray x;
    private INDArray y;
    private int batch_size;

    public MyTrainData(INDArray x, INDArray y, int batch_size) {
        this.x = x.transpose(); // store samples as columns
        this.y = y.transpose();
        this.batch_size = batch_size;
    }

    public MyTrainData(INDArray x, INDArray y) {
        this.x = x.transpose();
        this.y = y.transpose();
        this.batch_size = -1; // -1 means a single batch holding everything
    }

    @Override
    public INDArray getX() { return x; }

    @Override
    public INDArray getY() { return y; }

    @Override
    public List<BatchData> getBatchList() {
        List<BatchData> res = new ArrayList<>();
        shufflecard(); // reshuffle the sample columns before slicing each epoch
        if (batch_size != -1) {
            int lastColumnOrder = 0;
            // (the rest of this method is cut off in the source; reconstructed)
            // walk the columns in batch_size steps, wrapping each slice in a BatchData:
            for (int i = batch_size; i <= getSize(); i += batch_size) {
                res.add(new BatchData(
                        x.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, i)),
                        y.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, i))));
                lastColumnOrder = i;
            }
        } else {
            res.add(new BatchData(x, y));
        }
        return res;
    }

    private void shufflecard() {
        // implementation cut off in the source: permutes the columns of
        // x and y with one shared random order
    }
}
3. Build the network and train it on the data:
public static final ClassPathResource TRAIN_IMAGES_FILE = new ClassPathResource("data/train-images.idx3-ubyte");
public static final ClassPathResource TRAIN_LABELS_FILE = new ClassPathResource("data/train-labels.idx1-ubyte");
public static final ClassPathResource TEST_IMAGES_FILE = new ClassPathResource("data/t10k-images.idx3-ubyte");
public static final ClassPathResource TEST_LABELS_FILE = new ClassPathResource("data/t10k-labels.idx1-ubyte");
private model pointmodel = new DeepNeuralNetWork(28 * 28)
        .addLayer(new MyLayer(1000, new Tanh()))
        .addLayer(new MyLayer(500, new Tanh()))
        .addLayer(new MyLayer(100, new Tanh()))
        .addLastLayer(new SotfMaxCrossEntropyLastLayer(10))
        .setIteration(10)
        .setLearningrate(0.06);
double[][] images = MnistReadUtil.getImages(TRAIN_IMAGES_FILE.getInputStream());
double[] labels = MnistReadUtil.getLabels(TRAIN_LABELS_FILE.getInputStream());
INDArray X = Nd4j.create(images);               // 60000,784
INDArray Y = Nd4j.create(labels).transpose();   // 60000,1
INDArray X_I = MyMathUtil.Normalization(X);     // scale pixels into [0,1] by the global max (255)
INDArray Y_I = MyMathUtil.ONEHOT(Y);            // 60000,10
TrainData data = new MyTrainData(X_I, Y_I, 128);
pointmodel.train(data);
4. Testing the model
double[][] t_images = MnistReadUtil.getImages(TEST_IMAGES_FILE.getInputStream());
double[] t_labels = MnistReadUtil.getLabels(TEST_LABELS_FILE.getInputStream());
INDArray X_t = MyMathUtil.Normalization(Nd4j.create(t_images));
INDArray Y_t = MyMathUtil.ONEHOT(Nd4j.create(t_labels).transpose());
TrainData data_t = new MyTrainData(X_t, Y_t);
INDArray X_P = pointmodel.predict(data_t.getX());
System.out.println("accuracy: " + scord(X_P, data_t.getY()) + "%");
// find the most probable class for each sample and compare it with the label.
private float scord(INDArray value, INDArray Y) {
    int res = 0;
    int sum = 0;
    double[][] s = value.transpose().toDoubleMatrix();
    double[][] Ys = Y.transpose().toDoubleMatrix();
    for (int i = 0; i < s.length; i++) {
        int order = 0;
        double Max = s[i][0];
        for (int j = 0; j < s[i].length; j++) {
            order = Max > s[i][j] ? order : j;
            Max = Max > s[i][j] ? Max : s[i][j];
        }
        // count a hit when the one-hot label is set at the argmax position
        // (guarding this with order > 0 would silently drop every prediction of the digit 0)
        if ((int) Ys[i][order] == 1) {
            res++;
        }
        sum++;
    }
    if (sum > 0) {
        return ((float) res / sum) * 100;
    } else {
        return 0;
    }
}
5. Uploading an image for prediction
public String predict(@RequestParam(value = "file") MultipartFile file, ModelMap map) {
    if (file.isEmpty()) {
        System.out.println("file is empty");
    }
    try {
        File my_file = File.createTempFile("tmp", null);
        file.transferTo(my_file);
        double[] m = MnistReadUtil.getSizeBlackWhiteImg(my_file, 28, 28);
        INDArray X_t = MyMathUtil.Normalization(Nd4j.create(m));
        INDArray X_P = pointmodel.predict(X_t.transpose());
        int number = getnumber(X_P);
        map.addAttribute("number", number);
        return "freemarker/mnist/predict";
    } catch (Exception e) {
        e.printStackTrace();
    }
    return "freemarker/fail";
}
private int getnumber(INDArray X) {
    double[] s = X.toDoubleVector();
    int res = 0;
    double Max = s[0];
    // argmax over the 10 class probabilities (loop body cut off in the source; reconstructed)
    for (int i = 0; i < s.length; i++) {
        if (s[i] > Max) {
            Max = s[i];
            res = i;
        }
    }
    return res;
}
Full source code: [email protected]:woshiyigebing/my_dl4j.git