Deep Learning - Building a Classifier from CSV Data

This post mainly shows how DL4J reads and manipulates CSV data. Although you will rarely do this directly in real projects, it is still useful for getting familiar with DL4J's basic data operations and structures.
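
Before reading the code, it helps to picture the row layout the example assumes. Judging from labelIndex = 4 and the column lookups in makeAnimalsForTesting below, each row of animals_train.csv presumably holds five comma-separated values: yearsLived, an integer code for eats, an integer code for sounds, weight, and the class label. The two rows below are hypothetical, just to illustrate the shape (no header row, since new CSVRecordReader() skips no lines):

19,0,0,10.0,2
9,1,1,60.0,1

The full code follows: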

public class BasicCSVClassifier {

    private static Logger log = LoggerFactory.getLogger(BasicCSVClassifier.class); // logger obtained via the factory method

    private static Map<Integer,String> eats = readEnumCSV("/DataExamples/animals/eats.csv"); // readEnumCSV reads a CSV straight into a map
    private static Map<Integer,String> sounds = readEnumCSV("/DataExamples/animals/sounds.csv");
    private static Map<Integer,String> classifiers = readEnumCSV("/DataExamples/animals/classifiers.csv");

    public static void main(String[] args){

        try {

            // The RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in the neural network

            int labelIndex = 4;     // 5 values in each row of the CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
            int numClasses = 3;     // 3 classes in the data set, with integer values 0, 1 or 2

            int batchSizeTraining = 30;    // load one batch of 30 rows as a single DataSet (loading everything into one DataSet is not recommended for large data sets)
            DataSet trainingData = readCSVDataset(
                    "/DataExamples/animals/animals_train.csv",
                    batchSizeTraining, labelIndex, numClasses); // readCSVDataset reads the CSV directly into a DataSet


            // this is the data we want to classify
            int batchSizeTest = 44; // test batch of 44, same mechanism as above
            DataSet testData = readCSVDataset("/DataExamples/animals/animals.csv",
                    batchSizeTest, labelIndex, numClasses);


            // make the data model for the records prior to normalization, because
            // normalization changes the data
            Map<Integer,Map<String,Object>> animals = makeAnimalsForTesting(testData); // animals has the shape {0={eats=Mice, sounds=Meow, weight=10.0, yearsLived=19}, 1={eats=Cats, sounds=Bark, weight=60.0, yearsLived=9}, ...}


            // We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
            DataNormalization normalizer = new NormalizerStandardize(); // the normalizer
            normalizer.fit(trainingData);       // Collect the statistics (mean/stdev) from the training data. This does not modify the input data.
            // After fit(), trainingData.getFeatures().mean(0) and trainingData.getFeatures().std(0)
            // give the per-column mean and standard deviation. At this point the normalizer has only
            // collected statistics; trainingData itself is unchanged. Running
            // System.out.println(trainingData) prints an input array of features and an output array
            // of labels, with the labels one-hot encoded, e.g. [0.00, 0.00, 1.00].
            normalizer.transform(trainingData); // Apply normalization to the training data; trainingData is now normalized in place.
            normalizer.transform(testData);     // Apply normalization to the test data, using statistics calculated from the *training* set.
            final int numInputs = 4; // 4 input features
            int outputNum = 3;       // 3 classes
            int iterations = 1000;   // 1000 training iterations
            long seed = 6;           // random seed

            log.info("Build model....");
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() // the same pattern as in the previous post
                    .seed(seed)
                    .iterations(iterations)
                    .activation("tanh")
                    .weightInit(WeightInit.XAVIER)
                    .learningRate(0.1)
                    .regularization(true).l2(1e-4)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
                            .build())
                    .layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
                            .build())
                    .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .activation("softmax")
                            .nIn(3).nOut(outputNum).build())
                    .backprop(true).pretrain(false)
                    .build();
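            // The network defined above: 4 inputs -> dense(3, tanh) -> dense(3, tanh)
            // -> softmax output over 3 classes, trained with negative log-likelihood loss,
            // learning rate 0.1 and L2 regularization 1e-4.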

            //run the model
            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();
            model.setListeners(new ScoreIterationListener(100)); // log the score every 100 iterations

            model.fit(trainingData); // train on the normalized training data

            // evaluate the model on the test set
            Evaluation eval = new Evaluation(3); // 3 classes, so pass 3
            INDArray output = model.output(testData.getFeatureMatrix()); // predict class labels from the test features

            eval.eval(testData.getLabels(), output); // the evaluator compares actual labels with predictions
            log.info(eval.stats()); // print the evaluation stats

            setFittedClassifiers(output, animals);
            logAnimals(animals); // print each animal's values as a string

        } catch (Exception e){
            e.printStackTrace();
        }

    }



    public static void logAnimals(Map<Integer,Map<String,Object>> animals){
        for(Map<String,Object> a : animals.values())
            log.info(a.toString());
    }

    public static void setFittedClassifiers(INDArray output, Map<Integer,Map<String,Object>> animals){
        for (int i = 0; i < output.rows() ; i++) { // match each row with a class, turning the class index back into a class name

            // set the classification from the fitted results
            animals.get(i).put("classifier",
                    classifiers.get(maxIndex(getFloatArrayFromSlice(output.slice(i))))); // uses the helper methods below

        }

    }


    /**
     * This method shows how to convert an INDArray to a float array, as one more
     * example of converting an INDArray to types that are more Java
     * centric.
     *
     * @param rowSlice
     * @return
     */
    public static float[] getFloatArrayFromSlice(INDArray rowSlice){
        float[] result = new float[rowSlice.columns()]; // allocate a float array as long as the row and fill it; remember output here is a one-hot class vector
        for (int i = 0; i < rowSlice.columns(); i++) {
            result[i] = rowSlice.getFloat(i);
        }
        return result;
    }
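
    // Note: depending on the ND4J version, rowSlice.toFloatVector() may do this
    // conversion in one call -- treat that as an assumption and check your version.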

    /**
     * find the maximum item index. This is used when the data is fitted and we
     * want to determine which class to assign the test row to
     *
     * @param vals
     * @return
     */
    public static int maxIndex(float[] vals){ // since output is a one-hot style class vector, the index of the maximum value is the predicted class
        int maxIndex = 0;
        for (int i = 1; i < vals.length; i++){
            float newnumber = vals[i];
            if ((newnumber > vals[maxIndex])){
                maxIndex = i;
            }
        }
        return maxIndex;
    }
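
    // An alternative sketch using ND4J directly (assuming org.nd4j.linalg.factory.Nd4j
    // is available): Nd4j.argMax(output, 1) returns the arg-max index of each row at once,
    // e.g. int predictedClass = Nd4j.argMax(output, 1).getInt(i);
    // The hand-rolled loop above is kept for illustration.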

    /**
     * take the dataset loaded for the matrix and make the record model out of it so
     * we can correlate the fitted classifier to the record. The dataset is put into this
     * structure so that each record can later be matched with its predicted class name.
     *
     * @param testData
     * @return
     */
    public static Map<Integer,Map<String,Object>> makeAnimalsForTesting(DataSet testData){
        Map<Integer,Map<String,Object>> animals = new HashMap<>(); // the result map

        INDArray features = testData.getFeatureMatrix(); // get the feature matrix
        for (int i = 0; i < features.rows() ; i++) { // iterate over the rows
            INDArray slice = features.slice(i); // slice() yields the features one row at a time; each row is then processed
            Map<String,Object> animal = new HashMap<>(); // map holding one animal's attributes

            // set the attributes: fill animal first, then put animal into animals
            animal.put("yearsLived", slice.getInt(0));
            animal.put("eats", eats.get(slice.getInt(1)));
            animal.put("sounds", sounds.get(slice.getInt(2)));
            animal.put("weight", slice.getFloat(3));

            animals.put(i,animal);
        }
        return animals;

    }


    public static Map<Integer,String> readEnumCSV(String csvFileClasspath) { // reads the CSV into a map
        try{
            List<String> lines = IOUtils.readLines(new ClassPathResource(csvFileClasspath).getInputStream()); // read line by line
            Map<Integer,String> enums = new HashMap<>(); // the result map
            for(String line : lines){ // fill the map and return it
                String[] parts = line.split(",");
                enums.put(Integer.parseInt(parts[0]),parts[1]);
            }
            return enums;
        } catch (Exception e){
            e.printStackTrace();
            return null;
        }

    }

    /**
     * used for testing and training
     *
     * @param csvFileClasspath
     * @param batchSize
     * @param labelIndex
     * @param numClasses
     * @return
     * @throws IOException
     * @throws InterruptedException
     */
    private static DataSet readCSVDataset( // the CSV loader
            String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
            throws IOException, InterruptedException{

        RecordReader rr = new CSVRecordReader(); // create a CSV record reader
        rr.initialize(new FileSplit(new ClassPathResource(csvFileClasspath).getFile())); // initialize the reader with the file
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr,batchSize,labelIndex,numClasses); // build a DataSet iterator from the record reader, batch size, label index and number of classes
        return iterator.next(); // return one batch as a DataSet (note: this returns a DataSet, not the iterator itself)
    }
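
    /**
     * A minimal sketch, not part of the original example, of training batch by batch
     * instead of loading everything with a single iterator.next() call (which the
     * comments above advise against for large data sets). The method name is an
     * illustrative assumption; normalization is omitted for brevity.
     */
    private static void fitInBatches(MultiLayerNetwork model,
            String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
            throws IOException, InterruptedException{

        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new ClassPathResource(csvFileClasspath).getFile()));
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr,batchSize,labelIndex,numClasses);
        while (iterator.hasNext()) {    // each next() yields at most batchSize rows
            model.fit(iterator.next()); // train on one mini-batch at a time
        }
    }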



}
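
For reference, readEnumCSV expects each line of eats.csv, sounds.csv and classifiers.csv to be an integer key followed by a name, since it splits on "," and calls Integer.parseInt(parts[0]). Judging from the animal map shown earlier, eats.csv presumably looks something like this (the exact codes and values are assumptions):

0,Mice
1,Cats

sounds.csv and classifiers.csv follow the same two-column layout.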
