DL4J的神经网络输入处理DataSet介绍

Deeplearning4j的数据是由一个叫做DataSet的对象传入网络进行训练的,DataSet由四个主要元素组成,Features,Labels,FeaturesMask,LabelsMask,这四个元素都是INDArray,即是N维矩阵或者叫做N维张量。一般来说是2-4维矩阵,分别对应全联接网络、RNN网络、CNN网络的输入。

四个元素简单介绍如下:

  • Features 特征,特征可以是N维矩阵,以RNN举例,RNN的输入矩阵各维度的维数是[MiniBatch,FeaturesLength,TimeSeqLength],其中第二个维度就是特征列的个数,单个值[x,y,z]的含义就是MiniBatch的某一个批次x,时间序列TimeSeq的某一个时间点z,某一个Feature特征y的值。
  • Labels 标签,标签的维度需要和特征相对应,还是以RNN为例,标签的维度的维数就是[MiniBatch,LabelLength,TimeSeqLength],其中,如果网络是一个分类器网络的话,LabelLength是对应标签的独热处理,即LabelLength相当于分类数classes;如果网络是一个回归函数的话,那么LabelLength就是对应输出的几个回归目标值y的个数(一般是一个)。即是说,FeaturesLength相当于网络的InputSize;LabelLength相当于网络的OutputSize。输入输出的宽度。
  • FeaturesMask以及LabelsMask 特征掩模和标签掩模,如果需要掩盖某些数据的输入输出,即我们需要扔掉一些数据的输入或者输出,比方说RNN序列输出我只需要输出最后一个,或者输入我只需要前三个,那么这两个元素就有用了。以RNN为例,如果我的Label每个TimeSeq只输出最后一个时间点的值,那么LabelsMask就可以这么写,labelsMask的维度是[MiniBatch,TimeSeqLength],比方说是[x,y]当且仅当y = TimeSeqLength - 1 的时候[x,y] = 1,其余[x,y] = 0。这样就写好了一个输出的掩模。

关于RNN掩模的具体介绍可以看官网:

通过DL4J使用循环网络

Dataset的初始化方法源码:

/**
 * Creates a dataset with the specified input matrix and labels
 *
 * @param first  the feature matrix
 * @param second the labels (these should be binarized label matrices such that the specified label
 *               has a value of 1 in the desired column with the label)
 */
public DataSet(INDArray first, INDArray second) {
    this(first, second, null, null);
}

/**Create a dataset with the specified input INDArray and labels (output) INDArray, plus (optionally) mask arrays
 * for the features and labels
 * @param features Features (input)
 * @param labels Labels (output)
 * @param featuresMask Mask array for features, may be null
 * @param labelsMask Mask array for labels, may be null
 */
public DataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
    this.features = features;
    this.labels = labels;
    this.featuresMask = featuresMask;
    this.labelsMask = labelsMask;

    // we want this dataset to be fully committed to device
    Nd4j.getExecutioner().commit();
}

可以用getRange函数来截取一部分数据。

DataSet可以用merge函数竖向拼接。

可以用load函数来从流(或将文件转化为流)中读取数据。

可以用save函数来将DataSet存成流或者文件。Save和load函数会规定好用一个Byte的数据表示读取的一些属性,这样load的时候就能正确解析。

 

 

你可能感兴趣的:(深度学习数据预处理)