深度学习之Deeplearning4j

正式开始学习深度学习了。由于工作中需要使用Java进行开发,我选择了Deeplearning4j这个工具,个人感觉还是比较好用的(其他的DL工具我也没有用过)。这里谈到的更多是如何应用,具体的理论知识以后再去学习。在使用之前,我先看了官网上的例子——一个使用CNN进行手写体识别的例子。跑完这个程序花了很长时间(大约4个小时),可能是没有充分利用CPU的缘故。代码读起来有些别扭,因为我以前看过部分Mahout和MLlib的代码,那里面的矩阵操作要多一些,这个以后再去研究。现在主要关注的是数据输入格式,因此我打算记录一下这方面的一些内容:
package org.deeplearning4j.datasets.mnist;


import java.io.FileNotFoundException;
import java.io.IOException;


/**
 * MNIST database image file. In addition to the common MNIST header, an image
 * file stores the number of rows and columns of every image.
 */
public class MnistImageFile extends MnistDbFile {
    private int rows;
    private int cols;

    /**
     * Creates a new MNIST database image file ready for reading.
     *
     * @param name the system-dependent filename
     * @param mode the access mode
     * @throws IOException if the file cannot be opened or its header cannot be read
     */
    public MnistImageFile(String name, String mode) throws IOException {
        super(name, mode);

        // Image-specific header fields: row count first, then column count.
        rows = readInt();
        cols = readInt();
    }

    /**
     * Reads the image at the current position.
     *
     * @return a {@code getRows() x getCols()} matrix of pixel values (each read as an
     *         unsigned byte), filled row by row
     * @throws IOException if the image cannot be read
     */
    public int[][] readImage() throws IOException {
        int nRows = getRows();
        int nCols = getCols();
        int[][] dat = new int[nRows][nCols];
        // Outer loop over rows, inner over columns, matching the array's shape.
        for (int r = 0; r < nRows; r++) {
            for (int c = 0; c < nCols; c++) {
                dat[r][c] = readUnsignedByte();
            }
        }
        return dat;
    }

    /**
     * Reads the specified number of images from the current position into a
     * {@code byte[nImages][rows*cols]} array. This returns the data set directly;
     * for how this kind of data is laid out, see the convolutional-network tutorial
     * on the official site: http://deeplearning4j.org/convolutionalnets.html
     *
     * <p>Note that the MNIST data set is stored as unsigned bytes; this method returns
     * signed bytes without conversion (i.e., same bits, but requires conversion before use).
     *
     * @param nImages number of images to read
     * @return one byte array of length {@code rows*cols} per image
     * @throws IOException if fewer than {@code nImages} complete images remain in the file
     */
    public byte[][] readImagesUnsafe(int nImages) throws IOException {
        byte[][] out = new byte[nImages][];
        for (int i = 0; i < nImages; i++) {
            out[i] = new byte[rows * cols];
            // readFully (unlike read) never returns a partially filled buffer silently.
            readFully(out[i]);
        }
        return out;
    }

    /**
     * Moves the cursor to the next image.
     *
     * @throws IOException if the cursor cannot be moved
     */
    public void nextImage() throws IOException {
        super.next();
    }

    /**
     * Moves the cursor to the previous image.
     *
     * @throws IOException if the cursor cannot be moved
     */
    public void prevImage() throws IOException {
        super.prev();
    }

    @Override
    protected int getMagicNumber() {
        return 2051;
    }

    /**
     * Number of rows per image.
     *
     * @return int
     */
    public int getRows() {
        return rows;
    }

    /**
     * Number of columns per image.
     *
     * @return int
     */
    public int getCols() {
        return cols;
    }

    /** Size of one image entry in bytes: {@code rows * cols}. */
    @Override
    public int getEntryLength() {
        return cols * rows;
    }

    @Override
    public int getHeaderSize() {
        return super.getHeaderSize() + 8; // two more integers - rows and columns
    }
}
package org.deeplearning4j.datasets.mnist;

import org.deeplearning4j.datasets.fetchers.MnistDataFetcher;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;

/** Reads the MNIST handwritten-digit data: an image file and its matching label file. */
public class MnistManager {
    MnistImageFile images;
    private MnistLabelFile labels;

    // Bulk-loaded copies of all images / labels, indexed by example number.
    private byte[][] imagesArr;
    private int[] labelsArr;
    private static final int HEADER_SIZE = 8;

    /**
     * Writes the given grayscale image to a file in the plain-text PPM ("P3") format.
     *
     * @param image the image as a {@code rows x cols} matrix of values in [0, 255]
     * @param ppmFileName name of the file to write
     * @throws IOException if the file cannot be written
     * @throws IllegalArgumentException if {@code image} has no rows
     */
    public static void writeImageToPpm(int[][] image, String ppmFileName) throws IOException {
        if (image.length == 0) {
            throw new IllegalArgumentException("image must have at least one row");
        }
        int rows = image.length;
        int cols = image[0].length;
        try (BufferedWriter ppmOut = new BufferedWriter(new FileWriter(ppmFileName))) {
            ppmOut.write("P3\n");
            // The PPM header lists width (columns) first, then height (rows), then max value.
            ppmOut.write(cols + " " + rows + " 255\n");
            for (int i = 0; i < rows; i++) {
                StringBuilder s = new StringBuilder();
                for (int j = 0; j < cols; j++) {
                    // P3 stores an R G B triple per pixel; a grayscale value is
                    // repeated in all three channels.
                    int v = image[i][j];
                    s.append(v).append(' ').append(v).append(' ').append(v).append(' ');
                }
                s.append('\n');
                ppmOut.write(s.toString());
            }
        }
    }

    /**
     * Constructs an instance managing the two given data files. Supports
     * <code>NULL</code> for one of the arguments when only one of the files
     * (images or labels) needs to be read.
     *
     * <p>All examples are bulk-loaded through temporary handles that are closed
     * afterwards, so the streaming cursors used by {@link #readImage()} and
     * {@link #readLabel()} start at the first entry.
     *
     * @param imagesFile can be <code>NULL</code>; all future operations using that file will fail
     * @param labelsFile can be <code>NULL</code>; all future operations using that file will fail
     * @param train whether to load the training-set size or the test-set size
     * @throws IOException if a file cannot be opened or read
     */
    public MnistManager(String imagesFile, String labelsFile, boolean train) throws IOException {
        int numExamples = train ? MnistDataFetcher.NUM_EXAMPLES : MnistDataFetcher.NUM_EXAMPLES_TEST;
        if (imagesFile != null) {
            images = new MnistImageFile(imagesFile, "r");
            imagesArr = readAllImages(imagesFile, numExamples);
        }
        if (labelsFile != null) {
            labels = new MnistLabelFile(labelsFile, "r");
            labelsArr = readAllLabels(labelsFile, numExamples);
        }
    }

    public MnistManager(String imagesFile, String labelsFile) throws IOException {
        this(imagesFile, labelsFile, true);
    }

    /** Reads {@code count} images from the start of {@code path} through a temporary, closed handle. */
    private static byte[][] readAllImages(String path, int count) throws IOException {
        MnistImageFile file = new MnistImageFile(path, "r");
        try {
            return file.readImagesUnsafe(count);
        } finally {
            file.close();
        }
    }

    /** Reads {@code count} labels from the start of {@code path} through a temporary, closed handle. */
    private static int[] readAllLabels(String path, int count) throws IOException {
        MnistLabelFile file = new MnistLabelFile(path, "r");
        try {
            return file.readLabels(count);
        } finally {
            file.close();
        }
    }

    /**
     * Reads the image at the current position of the images file.
     *
     * @return matrix
     * @throws IOException if the image cannot be read
     * @throws IllegalStateException if no images file was given
     */
    public int[][] readImage() throws IOException {
        if (images == null) {
            throw new IllegalStateException("Images file not initialized.");
        }
        return images.readImage();
    }

    /** Returns the bulk-loaded image {@code i} as raw (signed) bytes. */
    public byte[] readImageUnsafe(int i) {
        return imagesArr[i];
    }

    /**
     * Sets the position to be read in whichever of the two files are open.
     *
     * @param index entry index
     */
    public void setCurrent(int index) {
        if (images != null) {
            images.setCurrentIndex(index);
        }
        if (labels != null) {
            labels.setCurrentIndex(index);
        }
    }

    /**
     * Reads the label at the current position of the labels file.
     *
     * @return int
     * @throws IOException if the label cannot be read
     * @throws IllegalStateException if no labels file was given
     */
    public int readLabel() throws IOException {
        if (labels == null) {
            throw new IllegalStateException("labels file not initialized.");
        }
        return labels.readLabel();
    }

    /** Returns the bulk-loaded label of example {@code i}. */
    public int readLabel(int i) {
        return labelsArr[i];
    }

    /**
     * Get the underlying images file as {@link MnistImageFile}.
     *
     * @return {@link MnistImageFile}.
     */
    public MnistImageFile getImages() {
        return images;
    }

    /**
     * Get the underlying labels file as {@link MnistLabelFile}.
     *
     * @return {@link MnistLabelFile}.
     */
    public MnistLabelFile getLabels() {
        return labels;
    }

    /** Closes any resources opened by the manager (best-effort; close failures are ignored). */
    public void close() {
        if (images != null) {
            try {
                images.close();
            } catch (IOException ignored) {
                // Best-effort cleanup: nothing useful can be done if close fails.
            }
            images = null;
        }
        if (labels != null) {
            try {
                labels.close();
            } catch (IOException ignored) {
                // Best-effort cleanup: nothing useful can be done if close fails.
            }
            labels = null;
        }
    }
}
package org.deeplearning4j.datasets.fetchers;
/** * 这个类的功能主要是获取手写体数据 */
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;


/**
 * Data fetcher for the MNIST dataset.
 *
 * @author Adam Gibson
 */
public class MnistDataFetcher extends BaseDataFetcher {
    public static final int NUM_EXAMPLES = 60000; // number of training examples loaded
    public static final int NUM_EXAMPLES_TEST = 10000; // number of test examples loaded
    // Base directory under which the MNIST directory lives.
    protected static final String TEMP_ROOT = System.getProperty("user.home");
    // Directory holding the unpacked MNIST files.
    protected static final String MNIST_ROOT = TEMP_ROOT + File.separator + "MNIST" + File.separator;

    // Pixels strictly above this value (0-255 scale) become 1 when binarizing, others 0.
    private static final int BINARIZE_THRESHOLD = 30;

    protected transient MnistManager man;
    // Whether to binarize pixels to {0, 1}; otherwise they are scaled into [0, 1].
    protected boolean binarize = true;
    // Whether this fetcher serves the training set (true) or the test set (false).
    protected boolean train;
    // Permutation of example indices that determines the iteration order.
    protected int[] order;
    // Random generator used to shuffle {@link #order}.
    protected Random rng;
    // Whether to shuffle the example order on every reset.
    protected boolean shuffle;

    /** Creates a shuffled training-set fetcher seeded with the current time. */
    public MnistDataFetcher(boolean binarize) throws IOException {
        this(binarize, true, true, System.currentTimeMillis());
    }

    public MnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed) throws IOException {
        if (!mnistExists()) {
            new MnistFetcher().downloadAndUntar(); // data not present yet: download it first
        }
        String images; // path of the image file
        String labels; // path of the label file
        if (train) {
            images = MNIST_ROOT + MnistFetcher.trainingFilesFilename_unzipped;
            labels = MNIST_ROOT + MnistFetcher.trainingFileLabelsFilename_unzipped;
            totalExamples = NUM_EXAMPLES;
        } else {
            images = MNIST_ROOT + MnistFetcher.testFilesFilename_unzipped;
            labels = MNIST_ROOT + MnistFetcher.testFileLabelsFilename_unzipped;
            totalExamples = NUM_EXAMPLES_TEST;
        }

        man = openManager(images, labels, train);

        numOutcomes = 10; // digit classes 0-9
        this.binarize = binarize;
        cursor = 0;
        inputColumns = man.getImages().getEntryLength(); // rows * cols of one image
        this.train = train;
        this.shuffle = shuffle;

        order = new int[train ? NUM_EXAMPLES : NUM_EXAMPLES_TEST];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        rng = new Random(rngSeed);
        reset(); // shuffles the order if requested
    }

    /**
     * Opens the data files; on any failure, deletes {@link #MNIST_ROOT}, downloads the
     * data again and retries once. If the retry also fails, the first failure is
     * attached to the thrown exception as a suppressed exception.
     */
    private static MnistManager openManager(String images, String labels, boolean train) throws IOException {
        try {
            return new MnistManager(images, labels, train);
        } catch (Exception firstFailure) {
            try {
                FileUtils.deleteDirectory(new File(MNIST_ROOT));
                new MnistFetcher().downloadAndUntar();
                return new MnistManager(images, labels, train);
            } catch (IOException | RuntimeException retryFailure) {
                retryFailure.addSuppressed(firstFailure);
                throw retryFailure;
            }
        }
    }

    /** Returns true if all four unpacked MNIST files are present under {@link #MNIST_ROOT}. */
    private boolean mnistExists() {
        // Check 4 files:
        File f = new File(MNIST_ROOT, MnistFetcher.trainingFilesFilename_unzipped);
        if (!f.exists()) return false;
        f = new File(MNIST_ROOT, MnistFetcher.trainingFileLabelsFilename_unzipped);
        if (!f.exists()) return false;
        f = new File(MNIST_ROOT, MnistFetcher.testFilesFilename_unzipped);
        if (!f.exists()) return false;
        f = new File(MNIST_ROOT, MnistFetcher.testFileLabelsFilename_unzipped);
        if (!f.exists()) return false;
        return true;
    }

    public MnistDataFetcher() throws IOException {
        this(true);
    }

    /**
     * Loads up to {@code numExamples} examples starting at the cursor (following
     * {@link #order}) and makes them the current data set.
     *
     * @throws IllegalStateException if no examples remain
     */
    @Override
    public void fetch(int numExamples) {
        if (!hasMore()) {
            throw new IllegalStateException("Unable to getFromOrigin more; there are no more images");
        }

        // Stop early rather than overshoot the total number of examples.
        List<DataSet> toConvert = new ArrayList<>(numExamples);
        for (int i = 0; i < numExamples && hasMore(); i++, cursor++) {
            int example = order[cursor];
            INDArray in = toFeatures(man.readImageUnsafe(example));
            INDArray out = createOutputVector(man.readLabel(example));
            toConvert.add(new DataSet(in, out));
        }
        initializeCurrFromList(toConvert);
    }

    /** Converts one raw image (bytes holding unsigned pixel values) into a 1 x n feature row. */
    private INDArray toFeatures(byte[] img) {
        INDArray in = Nd4j.create(1, img.length);
        for (int j = 0; j < img.length; j++) {
            int pixel = img[j] & 0xFF; // Java bytes are signed; recover the unsigned value
            if (binarize) {
                in.putScalar(j, pixel > BINARIZE_THRESHOLD ? 1 : 0);
            } else {
                in.putScalar(j, pixel);
            }
        }
        if (!binarize) {
            in.divi(255); // scale into [0, 1]
        }
        return in;
    }

    @Override
    public void reset() {
        cursor = 0;
        curr = null;
        if (shuffle) MathUtils.shuffleArray(order, rng);
    }

    @Override
    public DataSet next() {
        DataSet next = super.next();
        return next;
    }

}

通过观察代码可以发现,数据处理的很多部分是可以共享的,关键在于实现自己的fetch方法:针对不同的任务,可以采用不同的方式来读取和转换数据。

你可能感兴趣的:(java,cnn,深度学习)