正式开始学习深度学习了。由于工作上要使用Java做开发,因此我选择了Deeplearning4j这个工具,个人感觉还是比较好用的(其他的DL工具我也没有使用过)。这里谈到的更多是关于如何应用,具体的理论知识以后再去学习。在使用之前我先看了官网上的例子,是一个使用CNN进行手写体识别的例子,跑完这个程序花了很长时间(大约4个小时),可能是CPU没有被充分利用的原因。代码读起来有些不习惯,因为以前看过部分Mahout和MLlib的代码,里面矩阵操作要多一些,这个后面再去研究。现在主要关注的是数据输入格式,因此我打算记录一下这方面的一些内容:
package org.deeplearning4j.datasets.mnist;
import java.io.FileNotFoundException;
import java.io.IOException;
/**
 * MNIST database image file.
 *
 * <p>In addition to the header handled by {@link MnistDbFile}, an image file stores two
 * more integers: the number of rows and the number of columns of every image entry.
 */
public class MnistImageFile extends MnistDbFile {
    private int rows;
    private int cols;

    /**
     * Creates a new MNIST database image file ready for reading.
     *
     * @param name the system-dependent filename
     * @param mode the access mode
     * @throws IOException if the superclass constructor or the header read fails
     */
    public MnistImageFile(String name, String mode) throws IOException {
        super(name, mode);
        // Image-specific header fields: row count first, then column count.
        rows = readInt();
        cols = readInt();
    }

    /**
     * Reads the image at the current position.
     *
     * <p>Pixels are consumed in row-major order, so the result is indexed as
     * {@code image[row][column]} and works for non-square images as well.
     *
     * @return matrix of {@code getRows()} x {@code getCols()} unsigned pixel values
     * @throws IOException if reading a pixel fails
     */
    public int[][] readImage() throws IOException {
        final int nRows = getRows();
        final int nCols = getCols();
        int[][] dat = new int[nRows][nCols];
        // Outer loop over rows, inner loop over columns: the indices must match the
        // dimensions used to allocate the array.
        for (int r = 0; r < nRows; r++) {
            for (int c = 0; c < nCols; c++) {
                dat[r][c] = readUnsignedByte(); // one byte per pixel, widened to int
            }
        }
        return dat;
    }

    /**
     * Reads the specified number of images from the current position into a
     * {@code byte[nImages][rows * cols]} array.
     *
     * <p>Note that the MNIST data set is stored as unsigned bytes; this method returns signed
     * bytes without conversion (i.e. same bits, but requires conversion before use).
     *
     * @param nImages number of images to read
     * @return one flattened image per outer array element
     * @throws IOException if reading fails
     */
    public byte[][] readImagesUnsafe(int nImages) throws IOException {
        // Allocate only the outer array; each row is created below.
        byte[][] out = new byte[nImages][];
        for (int i = 0; i < nImages; i++) {
            out[i] = new byte[rows * cols];
            // NOTE(review): the number of bytes actually read is not checked here;
            // confirm that read(byte[]) of the superclass always fills the buffer.
            read(out[i]);
        }
        return out;
    }

    /**
     * Moves the cursor to the next image.
     *
     * @throws IOException if the superclass fails to advance
     */
    public void nextImage() throws IOException {
        super.next();
    }

    /**
     * Moves the cursor to the previous image.
     *
     * @throws IOException if the superclass fails to step back
     */
    public void prevImage() throws IOException {
        super.prev();
    }

    /** Returns the magic number identifying an image file (2051). */
    @Override
    protected int getMagicNumber() {
        return 2051;
    }

    /**
     * Number of rows per image.
     *
     * @return int
     */
    public int getRows() {
        return rows;
    }

    /**
     * Number of columns per image.
     *
     * @return int
     */
    public int getCols() {
        return cols;
    }

    /** Returns the size of one image entry: {@code cols * rows} (one byte per pixel). */
    @Override
    public int getEntryLength() {
        return cols * rows;
    }

    /** Returns the superclass header size plus 8 bytes for the row and column counts. */
    @Override
    public int getHeaderSize() {
        return super.getHeaderSize() + 8; // two more integers - rows and columns
    }
}
package org.deeplearning4j.datasets.mnist;
import org.deeplearning4j.datasets.fetchers.MnistDataFetcher;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
/**
 * Reads MNIST handwritten-digit data from an image file and a label file.
 *
 * <p>The manager keeps one open handle per file for cursor-based reading
 * ({@link #readImage()}, {@link #readLabel()}) and additionally loads all images and labels
 * into memory at construction time for index-based access
 * ({@link #readImageUnsafe(int)}, {@link #readLabel(int)}).
 */
public class MnistManager {
    MnistImageFile images;
    private MnistLabelFile labels;
    private byte[][] imagesArr; // raw bytes of every image, loaded up front
    private int[] labelsArr;    // label of every example, loaded up front
    private static final int HEADER_SIZE = 8;

    /**
     * Writes the given image to the given file using the plain PPM ("P3") data format.
     *
     * <p>PPM stores RGB triples, so each grey value is written three times. The header
     * lists the width (columns) before the height (rows).
     *
     * @param image       matrix representation of the image, indexed {@code [row][column]}
     * @param ppmFileName name of the file to write
     * @throws IOException              if the file cannot be written
     * @throws IllegalArgumentException if {@code image} has no rows
     */
    public static void writeImageToPpm(int[][] image, String ppmFileName) throws IOException {
        // Validate before opening the file so a bad argument does not leave an empty file behind.
        final int rows = image.length;
        if (rows == 0) {
            throw new IllegalArgumentException("image must contain at least one row");
        }
        final int cols = image[0].length;
        try (BufferedWriter ppmOut = new BufferedWriter(new FileWriter(ppmFileName))) {
            ppmOut.write("P3\n");
            // PPM header order is "width height maxval", i.e. columns first.
            ppmOut.write(cols + " " + rows + " 255\n");
            for (int i = 0; i < rows; i++) {
                StringBuilder s = new StringBuilder();
                for (int j = 0; j < cols; j++) {
                    int v = image[i][j];
                    // Same value for the R, G and B channel of a greyscale pixel.
                    s.append(v).append(' ').append(v).append(' ').append(v).append(' ');
                }
                s.append('\n'); // whitespace only; keeps one image row per text line
                ppmOut.write(s.toString());
            }
        }
    }

    /**
     * Constructs an instance managing the two given data files. Supports {@code null} for one
     * of the arguments in case reading only one of the files (images or labels) is required.
     *
     * <p>The bulk load into memory uses temporary file handles that are closed immediately,
     * so it never moves the read position of the handles used by {@link #readImage()} and
     * {@link #readLabel()}.
     *
     * @param imagesFile can be {@code null}; in that case operations using that file will fail
     * @param labelsFile can be {@code null}; in that case operations using that file will fail
     * @param train      {@code true} to load {@code MnistDataFetcher.NUM_EXAMPLES} entries,
     *                   {@code false} to load {@code MnistDataFetcher.NUM_EXAMPLES_TEST}
     * @throws IOException if a file cannot be opened or read
     */
    public MnistManager(String imagesFile, String labelsFile, boolean train) throws IOException {
        final int numExamples =
                train ? MnistDataFetcher.NUM_EXAMPLES : MnistDataFetcher.NUM_EXAMPLES_TEST;
        try {
            if (imagesFile != null) {
                images = new MnistImageFile(imagesFile, "r");
                imagesArr = loadAllImages(imagesFile, numExamples);
            }
            if (labelsFile != null) {
                labels = new MnistLabelFile(labelsFile, "r");
                labelsArr = loadAllLabels(labelsFile, numExamples);
            }
        } catch (IOException | RuntimeException e) {
            // Do not leak handles that were already opened when construction fails part-way.
            releaseFiles();
            throw e;
        }
    }

    /**
     * Constructs an instance for the training set; equivalent to passing {@code train = true}.
     *
     * @param imagesFile can be {@code null}
     * @param labelsFile can be {@code null}
     * @throws IOException if a file cannot be opened or read
     */
    public MnistManager(String imagesFile, String labelsFile) throws IOException {
        this(imagesFile, labelsFile, true);
    }

    /** Reads {@code numExamples} raw images through a temporary handle that is always closed. */
    private static byte[][] loadAllImages(String imagesFile, int numExamples) throws IOException {
        MnistImageFile bulk = new MnistImageFile(imagesFile, "r");
        try {
            return bulk.readImagesUnsafe(numExamples);
        } finally {
            bulk.close();
        }
    }

    /** Reads {@code numExamples} labels through a temporary handle that is always closed. */
    private static int[] loadAllLabels(String labelsFile, int numExamples) throws IOException {
        MnistLabelFile bulk = new MnistLabelFile(labelsFile, "r");
        try {
            return bulk.readLabels(numExamples);
        } finally {
            bulk.close();
        }
    }

    /**
     * Reads the image at the current position of the images file.
     *
     * @return matrix
     * @throws IOException           if reading fails
     * @throws IllegalStateException if no images file was supplied (or it has been closed)
     */
    public int[][] readImage() throws IOException {
        if (images == null) {
            throw new IllegalStateException("Images file not initialized.");
        }
        return images.readImage();
    }

    /**
     * Returns the raw (signed) bytes of the {@code i}-th image loaded at construction time.
     *
     * @param i index of the image
     * @return the stored array itself, not a copy
     */
    public byte[] readImageUnsafe(int i) {
        return imagesArr[i];
    }

    /**
     * Sets the position to be read on every file that is currently open.
     *
     * @param index index of the entry to move to
     */
    public void setCurrent(int index) {
        // Either file may legitimately be absent (see the constructor), so guard both.
        if (images != null) {
            images.setCurrentIndex(index);
        }
        if (labels != null) {
            labels.setCurrentIndex(index);
        }
    }

    /**
     * Reads the label at the current position of the labels file.
     *
     * @return int
     * @throws IOException           if reading fails
     * @throws IllegalStateException if no labels file was supplied (or it has been closed)
     */
    public int readLabel() throws IOException {
        if (labels == null) {
            throw new IllegalStateException("labels file not initialized.");
        }
        return labels.readLabel();
    }

    /**
     * Returns the {@code i}-th label loaded at construction time.
     *
     * @param i index of the label
     * @return the label value
     */
    public int readLabel(int i) {
        return labelsArr[i];
    }

    /**
     * Get the underlying images file as {@link MnistImageFile}.
     *
     * @return {@link MnistImageFile}.
     */
    public MnistImageFile getImages() {
        return images;
    }

    /**
     * Get the underlying labels file as {@link MnistLabelFile}.
     *
     * @return {@link MnistLabelFile}.
     */
    public MnistLabelFile getLabels() {
        return labels;
    }

    /** Close any resources opened by the manager. Safe to call more than once. */
    public void close() {
        releaseFiles();
    }

    /** Closes both file handles on a best-effort basis and clears the references. */
    private void releaseFiles() {
        if (images != null) {
            try {
                images.close();
            } catch (IOException ignored) {
                // Best-effort close: nothing useful can be done if closing fails.
            }
            images = null;
        }
        if (labels != null) {
            try {
                labels.close();
            } catch (IOException ignored) {
                // Best-effort close: nothing useful can be done if closing fails.
            }
            labels = null;
        }
    }
}
package org.deeplearning4j.datasets.fetchers;
/** * 这个类的功能主要是获取手写体数据 */
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Data fetcher for the MNIST dataset.
 *
 * <p>Loads either the training or the test split through a {@link MnistManager} and turns
 * the raw image bytes and labels into {@link DataSet} examples, optionally in shuffled order.
 *
 * @author Adam Gibson
 */
public class MnistDataFetcher extends BaseDataFetcher {
    /** Number of examples loaded for the training split. */
    public static final int NUM_EXAMPLES = 60000;
    /** Number of examples loaded for the test split. */
    public static final int NUM_EXAMPLES_TEST = 10000;

    /** Base directory for the data: the user's home directory. */
    protected static final String TEMP_ROOT = System.getProperty("user.home");
    /** Directory (with trailing separator) that holds the MNIST files. */
    protected static final String MNIST_ROOT = TEMP_ROOT + File.separator + "MNIST" + File.separator;

    protected transient MnistManager man;
    /** Whether pixel values are thresholded to {0, 1} instead of being scaled by 255. */
    protected boolean binarize = true;
    /** Whether this fetcher serves the training split (otherwise the test split). */
    protected boolean train;
    /** Order in which examples are served; shuffled on reset when {@link #shuffle} is set. */
    protected int[] order;
    /** Random source used for shuffling. */
    protected Random rng;
    /** Whether the example order is shuffled on every reset. */
    protected boolean shuffle;

    /**
     * Creates a binarizing fetcher for the training split with shuffling enabled.
     *
     * @throws IOException if the data cannot be downloaded or read
     */
    public MnistDataFetcher() throws IOException {
        this(true);
    }

    /**
     * Creates a shuffling fetcher for the training split, seeded with the current time.
     *
     * @param binarize whether to threshold pixel values to {0, 1}
     * @throws IOException if the data cannot be downloaded or read
     */
    public MnistDataFetcher(boolean binarize) throws IOException {
        this(binarize, true, true, System.currentTimeMillis());
    }

    /**
     * Creates a fetcher, downloading the data set first when any of its files is missing.
     *
     * @param binarize whether to threshold pixel values to {0, 1}
     * @param train    {@code true} for the training split, {@code false} for the test split
     * @param shuffle  whether to shuffle the example order on reset
     * @param rngSeed  seed for the shuffling random number generator
     * @throws IOException if the data cannot be downloaded or read
     */
    public MnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed) throws IOException {
        if (!mnistExists()) {
            new MnistFetcher().downloadAndUntar();
        }

        final String imagePath;
        final String labelPath;
        if (train) {
            imagePath = MNIST_ROOT + MnistFetcher.trainingFilesFilename_unzipped;
            labelPath = MNIST_ROOT + MnistFetcher.trainingFileLabelsFilename_unzipped;
            totalExamples = NUM_EXAMPLES;
        } else {
            imagePath = MNIST_ROOT + MnistFetcher.testFilesFilename_unzipped;
            labelPath = MNIST_ROOT + MnistFetcher.testFileLabelsFilename_unzipped;
            totalExamples = NUM_EXAMPLES_TEST;
        }

        try {
            man = new MnistManager(imagePath, labelPath, train);
        } catch (Exception firstAttemptFailure) {
            // Loading failed: wipe the local copy, download again and retry once.
            FileUtils.deleteDirectory(new File(MNIST_ROOT));
            new MnistFetcher().downloadAndUntar();
            man = new MnistManager(imagePath, labelPath, train);
        }

        this.binarize = binarize;
        this.train = train;
        this.shuffle = shuffle;
        numOutcomes = 10; // digit classes 0..9
        cursor = 0;
        inputColumns = man.getImages().getEntryLength(); // rows * cols of one image

        // Start from the identity ordering; reset() shuffles it if requested.
        order = new int[train ? NUM_EXAMPLES : NUM_EXAMPLES_TEST];
        for (int idx = 0; idx < order.length; idx++) {
            order[idx] = idx;
        }
        rng = new Random(rngSeed);
        reset();
    }

    /** Returns {@code true} only when all four unzipped MNIST files exist under the root. */
    private boolean mnistExists() {
        final String[] requiredFiles = {
            MnistFetcher.trainingFilesFilename_unzipped,
            MnistFetcher.trainingFileLabelsFilename_unzipped,
            MnistFetcher.testFilesFilename_unzipped,
            MnistFetcher.testFileLabelsFilename_unzipped
        };
        for (String fileName : requiredFiles) {
            if (!new File(MNIST_ROOT, fileName).exists()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds up to {@code numExamples} examples starting at the cursor and hands them to
     * {@code initializeCurrFromList}. Stops early when no more examples are available.
     *
     * @param numExamples maximum number of examples to fetch
     * @throws IllegalStateException if there are no more examples at the time of the call
     */
    @Override
    public void fetch(int numExamples) {
        if (!hasMore()) {
            throw new IllegalStateException("Unable to getFromOrigin more; there are no more images");
        }

        List<DataSet> batch = new ArrayList<>(numExamples);
        int taken = 0;
        // Bound by both the request size and the remaining examples so we never overshoot.
        while (taken < numExamples && hasMore()) {
            final int exampleIdx = order[cursor];
            INDArray features = toFeatureRow(man.readImageUnsafe(exampleIdx));
            INDArray label = createOutputVector(man.readLabel(exampleIdx));
            batch.add(new DataSet(features, label));
            taken++;
            cursor++;
        }
        initializeCurrFromList(batch);
    }

    /**
     * Converts the raw bytes of one image into a {@code 1 x img.length} array, then either
     * thresholds the values (greater than 30 becomes 1, otherwise 0) or divides them by 255,
     * depending on {@link #binarize}.
     */
    private INDArray toFeatureRow(byte[] img) {
        INDArray in = Nd4j.create(1, img.length);
        for (int j = 0; j < img.length; j++) {
            // Bytes are loaded as signed; mask to recover the unsigned pixel value.
            in.putScalar(j, ((int) img[j]) & 0xFF);
        }
        if (!binarize) {
            in.divi(255);
            return in;
        }
        for (int d = 0; d < in.length(); d++) {
            in.putScalar(d, in.getDouble(d) > 30 ? 1 : 0);
        }
        return in;
    }

    /** Rewinds the cursor, clears the current batch and reshuffles the order if enabled. */
    @Override
    public void reset() {
        cursor = 0;
        curr = null;
        if (shuffle) {
            MathUtils.shuffleArray(order, rng);
        }
    }

    /** Delegates to the superclass implementation. */
    @Override
    public DataSet next() {
        return super.next();
    }
}
通过观察代码可以发现,数据读取的很多部分是可以共享的,关键就是实现自己的fetch方法;对于不同的任务,我们可以采取不同的方式来实现。