This post mainly looks at how DL4J works with CSV data. You rarely load data this way in real projects, but it is still a good way to get familiar with the basic data operations and data structures. The code is as follows:
public class BasicCSVClassifier {

    private static Logger log = LoggerFactory.getLogger(BasicCSVClassifier.class);     // logger created via the factory method

    // readEnumCSV reads each CSV straight into a Map<Integer,String>
    private static Map<Integer,String> eats = readEnumCSV("/DataExamples/animals/eats.csv");
    private static Map<Integer,String> sounds = readEnumCSV("/DataExamples/animals/sounds.csv");
    private static Map<Integer,String> classifiers = readEnumCSV("/DataExamples/animals/classifiers.csv");

    public static void main(String[] args){

        try {
            // The RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
            int labelIndex = 4;     // 5 values in each row of the CSV: 4 input features followed by an integer label (class) index; the label is the 5th value (index 4)
            int numClasses = 3;     // 3 classes in the data set, with integer values 0, 1 or 2

            int batchSizeTraining = 30;     // training batch size of 30; the whole training file is loaded into one DataSet (not recommended for large data sets)
            DataSet trainingData = readCSVDataset(
                    "/DataExamples/animals/animals_train.csv",
                    batchSizeTraining, labelIndex, numClasses);     // readCSVDataset reads the CSV directly into a DataSet

            // this is the data we want to classify
            int batchSizeTest = 44;     // test batch size of 44, same idea as above
            DataSet testData = readCSVDataset("/DataExamples/animals/animals.csv",
                    batchSizeTest, labelIndex, numClasses);

            // make the data model for the records prior to normalization, because normalization changes the data.
            // animals ends up looking like {0={eats=Mice, sounds=Meow, weight=10.0, yearsLived=19}, 1={eats=Cats, sounds=Bark, weight=60.0, yearsLived=9}, ...}
            Map<Integer,Map<String,Object>> animals = makeAnimalsForTesting(testData);

            // We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
            DataNormalization normalizer = new NormalizerStandardize();
            normalizer.fit(trainingData);   // Collect the statistics (mean/stdev) from the training data. This does not modify the input data.
            // After fit(), the per-column mean and standard deviation can be inspected via
            //     trainingData.getFeatures().mean(0)
            //     trainingData.getFeatures().std(0)
            // Note that fit() only collects statistics from trainingData; trainingData itself is unchanged.
            // System.out.println(trainingData) would print an input array holding the features and an output
            // array holding the labels, with the labels one-hot encoded, e.g. [0.00, 0.00, 1.00].
            normalizer.transform(trainingData);     // Apply normalization to the training data; trainingData is now normalized in place
            normalizer.transform(testData);         // Apply normalization to the test data, using statistics calculated from the *training* set
            final int numInputs = 4;    // 4 input features
            int outputNum = 3;          // 3 classes
            int iterations = 1000;      // 1000 iterations
            long seed = 6;              // random seed

            log.info("Build model....");
            // the configuration below follows the same pattern as the previous post
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(seed)
                    .iterations(iterations)
                    .activation("tanh")
                    .weightInit(WeightInit.XAVIER)
                    .learningRate(0.1)
                    .regularization(true).l2(1e-4)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
                            .build())
                    .layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
                            .build())
                    .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .activation("softmax")
                            .nIn(3).nOut(outputNum).build())
                    .backprop(true).pretrain(false)
                    .build();

            // run the model
            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();
            model.setListeners(new ScoreIterationListener(100));

            model.fit(trainingData);

            // evaluate the model on the test set
            Evaluation eval = new Evaluation(3);    // 3 classes, so pass 3
            INDArray output = model.output(testData.getFeatureMatrix());   // predict class labels from the test features

            eval.eval(testData.getLabels(), output);    // evaluate the predicted labels against the actual labels
            log.info(eval.stats());                     // print the evaluation report

            setFittedClassifiers(output, animals);
            logAnimals(animals);    // print the values of animals as strings

        } catch (Exception e){
            e.printStackTrace();
        }
    }

    public static void logAnimals(Map<Integer,Map<String,Object>> animals){
        for(Map<String,Object> a : animals.values())
            log.info(a.toString());
    }

    public static void setFittedClassifiers(INDArray output, Map<Integer,Map<String,Object>> animals){
        // assign a class to every row, mapping the predicted index back to a class name (using the helper methods below)
        for (int i = 0; i < output.rows(); i++) {
            // set the classification from the fitted results
            animals.get(i).put("classifier",
                    classifiers.get(maxIndex(getFloatArrayFromSlice(output.slice(i)))));
        }
    }

    /**
     * This method shows how to convert an INDArray to a float array, as an example of
     * converting INDArray to more Java-centric types.
     *
     * @param rowSlice
     * @return
     */
    public static float[] getFloatArrayFromSlice(INDArray rowSlice){
        // build a float array with the same length as rowSlice and fill it;
        // remember that output here holds one-hot class vectors
        float[] result = new float[rowSlice.columns()];
        for (int i = 0; i < rowSlice.columns(); i++) {
            result[i] = rowSlice.getFloat(i);
        }
        return result;
    }

    /**
     * Find the index of the maximum item. This is used on the fitted output to
     * determine which class to assign the test row to: since output holds one-hot
     * class vectors, the index of the largest value is the predicted class.
     *
     * @param vals
     * @return
     */
    public static int maxIndex(float[] vals){
        int maxIndex = 0;
        for (int i = 1; i < vals.length; i++){
            float newnumber = vals[i];
            if ((newnumber > vals[maxIndex])){
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Take the loaded dataset and build the record model out of it so we can
     * correlate the fitted classifier to each record, i.e. attach a class name
     * to every row.
     *
     * @param testData
     * @return
     */
    public static Map<Integer,Map<String,Object>> makeAnimalsForTesting(DataSet testData){
        Map<Integer,Map<String,Object>> animals = new HashMap<>();

        INDArray features = testData.getFeatureMatrix();    // get the features
        for (int i = 0; i < features.rows(); i++) {         // iterate over the rows
            INDArray slice = features.slice(i);             // slice(i) gives one row, which is then unpacked field by field
            Map<String,Object> animal = new HashMap<>();

            // set the attributes: fill animal first, then put animal into animals
            animal.put("yearsLived", slice.getInt(0));
            animal.put("eats", eats.get(slice.getInt(1)));
            animal.put("sounds", sounds.get(slice.getInt(2)));
            animal.put("weight", slice.getFloat(3));

            animals.put(i, animal);
        }
        return animals;
    }

    // reads an enum CSV into a Map<Integer,String>
    public static Map<Integer,String> readEnumCSV(String csvFileClasspath) {
        try{
            List<String> lines = IOUtils.readLines(new ClassPathResource(csvFileClasspath).getInputStream());   // read line by line
            Map<Integer,String> enums = new HashMap<>();
            for(String line : lines){   // fill the map and return it
                String[] parts = line.split(",");
                enums.put(Integer.parseInt(parts[0]), parts[1]);
            }
            return enums;
        } catch (Exception e){
            e.printStackTrace();
            return null;
        }
    }

    /**
     * used for testing and training
     *
     * @param csvFileClasspath
     * @param batchSize
     * @param labelIndex
     * @param numClasses
     * @return
     * @throws IOException
     * @throws InterruptedException
     */
    private static DataSet readCSVDataset(
            String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
            throws IOException, InterruptedException {

        RecordReader rr = new CSVRecordReader();    // create the CSV record reader
        rr.initialize(new FileSplit(new ClassPathResource(csvFileClasspath).getFile()));    // initialize the reader with the file
        // build a DataSet iterator from the record reader, batch size, label index and number of classes
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
        return iterator.next();     // return the first (and only) DataSet from the iterator
    }
}
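To make the fit/transform distinction above concrete, here is a minimal standalone sketch. The class name and the toy numbers are my own for illustration and are not part of the example; it simply shows that fit() only records the per-column statistics while transform() rewrites the DataSet in place.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;

public class NormalizerDemo {
    public static void main(String[] args) {
        // two rows, two feature columns, one-hot labels for two classes (toy data)
        INDArray features = Nd4j.create(new double[][]{{1.0, 10.0}, {3.0, 30.0}});
        INDArray labels   = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}});
        DataSet ds = new DataSet(features, labels);

        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(ds);                               // collect per-column mean/stdev; ds is untouched

        System.out.println(ds.getFeatures().mean(0));     // per-column means, here [2.00, 20.00]
        System.out.println(ds.getFeatures().std(0));      // per-column standard deviations

        normalizer.transform(ds);                         // now the features in ds are rewritten in place
        System.out.println(ds);                           // roughly zero-mean, unit-variance features; labels unchanged
    }
}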
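It also helps to see what readEnumCSV expects: each line is an "index,name" pair, and the resulting map is what makeAnimalsForTesting uses to turn the integer codes stored in animals.csv back into readable attribute values. The lines below are hypothetical (the real eats.csv may use different indices and values); the parsing is the same as in readEnumCSV.

import java.util.HashMap;
import java.util.Map;

public class EnumCsvDemo {
    public static void main(String[] args) {
        // hypothetical lines from an enum CSV such as eats.csv
        String[] lines = {"0,Grass", "1,Mice", "2,Cats"};

        Map<Integer, String> eats = new HashMap<>();
        for (String line : lines) {                         // same parsing as readEnumCSV
            String[] parts = line.split(",");
            eats.put(Integer.parseInt(parts[0]), parts[1]);
        }
        System.out.println(eats.get(1));                    // -> "Mice"
    }
}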