Weka中ROC和AUC通过调用API实现

Weka中实现计算ROC的是ThresholdCurve类的getCurve函数

(1)getCurve函数

@param predictions 预测结果集合,一般是十折交叉验证得到的predictions
@param classIndex 视为正例的类别索引;对于多元分类问题,将该类别设为正例,其他类别均为负例
@return datapoints,即以Instances形式保存的各阈值下的统计量

    // Excerpt from getCurve: add this prediction's weight to the positive total
    // when its actual class equals classIndex, otherwise to the negative total.
    if (pred.actual() == classIndex) {
      totPos += pred.weight();// accumulate weight of positive instances
    } else {
      totNeg += pred.weight();// accumulate weight of negative instances
    }
    // Obtain the Instances object (from makeHeader) that will hold the curve's data points.
    Instances insts = makeHeader();
    此时得到insts为:
@relation ThresholdCurve
@attribute 'True Positives' numeric
@attribute 'False Negatives' numeric
@attribute 'False Positives' numeric
@attribute 'True Negatives' numeric
@attribute 'False Positive Rate' numeric
@attribute 'True Positive Rate' numeric
@attribute Precision numeric
@attribute Recall numeric
@attribute Fallout numeric
@attribute FMeasure numeric
@attribute 'Sample Size' numeric
@attribute Lift numeric
@attribute Threshold numeric
@data

    // Excerpt from getCurve: sweep the threshold upward over the sorted
    // probabilities and emit one data point per distinct threshold value.
    Instances insts = makeHeader();
    // Sort ascending by probability so that negatives tend to gather at the front
    // and positives at the end. `sorted` holds the indices into probs in ascending
    // order, not the probability values themselves.
    int[] sorted = Utils.sort(probs);
    // Starting statistics; presumably the constructor order is (TP, FP, TN, FN),
    // i.e. everything is initially counted as predicted positive -- TODO confirm
    // against the TwoClassStats API.
    TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0);
    double threshold = 0;
    // Weight of positives / negatives seen since the last emitted point.
    double cumulativePos = 0;
    double cumulativeNeg = 0;

    for (int i = 0; i < sorted.length; i++) {
      // A new point is emitted for the first instance and whenever the
      // probability strictly exceeds the current threshold.
      if ((i == 0) || (probs[sorted[i]] > threshold)) {
        // Instances accumulated since the last point move from the
        // "predicted positive" counts to the "predicted negative" counts.
        tc.setTruePositive(tc.getTruePositive() - cumulativePos);
        tc.setFalseNegative(tc.getFalseNegative() + cumulativePos);
        tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg);
        tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg);
        threshold = probs[sorted[i]];
        insts.add(makeInstance(tc, threshold));
        cumulativePos = 0;
        cumulativeNeg = 0;
        // The last instance only contributes a point; nothing follows it.
        if (i == sorted.length - 1) {
          break;
        }
      }
      NominalPrediction pred = (NominalPrediction) predictions.get(sorted[i]);

      // Accumulate this instance's weight by its actual class.
      if (pred.actual() == classIndex) {
        cumulativePos += pred.weight();
      } else {
        cumulativeNeg += pred.weight();
      }
    }

    // make sure a zero point gets into the curve:
    // if the last emitted statistics do not already have FN == totPos and
    // TN == totNeg, append one such point at a threshold slightly above the
    // largest probability (note: 10e-6 equals 1e-5).
    if (tc.getFalseNegative() != totPos || tc.getTrueNegative() != totNeg) {
      tc = new TwoClassStats(0, 0, totNeg, totPos);
      threshold = probs[sorted[sorted.length - 1]] + 10e-6;
      insts.add(makeInstance(tc, threshold));
    }

    return insts;

利用Weka画ROC和计算AUC的方法:
来自《数据挖掘与机器学习:WEKA应用技术与实践(第二版)》

    public static void test1()throws Exception{
        ArffLoader loader=new ArffLoader();
        loader.setSource(new File("./data/weather.nominal.arff"));
        Instances data=loader.getDataSet();
        data.setClassIndex(data.numAttributes()-1);
        Classifier classifier =new NaiveBayes();
        Evaluation eval=new Evaluation(data);
        eval.crossValidateModel(classifier, data, 10, new Random(1));

        ThresholdCurve tc=new ThresholdCurve();
        int classIndex=0;
        Instances curve =tc.getCurve(eval.predictions(),classIndex);

        PlotData2D plotdata=new PlotData2D(curve);
        plotdata.setPlotName(curve.relationName());
        plotdata.addInstanceNumberAttribute();

        ThresholdVisualizePanel tvp=new ThresholdVisualizePanel();
        tvp.setROCString("(Area under ROC=" +
                Utils.doubleToString(ThresholdCurve.getROCArea(curve), 4)+")");
        tvp.setName(curve.relationName());
        boolean [] cp=new boolean[curve.numInstances()];
        for(int i=0;i
            cp[i]=true;
        plotdata.setConnectPoints(cp);
        tvp.addPlot(plotdata);

        final JFrame jf=new JFrame("WEKA ROC: "+tvp.getName());
        jf.setSize(500,400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(tvp, BorderLayout.CENTER);
        jf.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
        jf.setVisible(true);

    }

Weka中ROC和AUC通过调用API实现_第1张图片
注意如果把交叉验证的Random(1)改为Random(1234)(书中的代码)的话得到的图稍有不同:
Weka中ROC和AUC通过调用API实现_第2张图片

另一个类似的方法也可参考:

    public static void test2() throws Exception {
        ArffLoader loader=new ArffLoader();
        loader.setSource(new File("./data/weather.nominal.arff"));
        Instances data=loader.getDataSet();
        data.setClassIndex(data.numAttributes() - 1);
        /*
         * 训练分类器并用十字交叉验证法来获得Evaluation对象
         * 注意这里的方法与我们在上几节中使用的验证法是不同。
         */
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));
        /*
         * 生成用于得到ROC曲面和AUC值的Instances对象
         * 顺带打印了一些其它信息,用于在SPSS中生成ROC曲面
         * 如果我们查看weka源码就会知道这个Instances对象包含了很多分类的结果信息
         * 例如:FMeasure、Recall、Precision、True Positive Rate、
         * False Positive Rate等等。我们可以用这些信息绘制各种曲面。
         */
        ThresholdCurve tc = new ThresholdCurve();
        // classIndex is the index of the class to consider as "positive"
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);
        System.out.println("The area under the ROC curve: " + eval.areaUnderROC(classIndex));
        /*
         * 在这里我们通过结果信息Instances对象得到包含TP、FP的两个数组
         * 这两个数组用于在SPSS中通过线图绘制ROC曲面
         */
        int tpIndex = result.attribute(ThresholdCurve.TP_RATE_NAME).index();
        int fpIndex = result.attribute(ThresholdCurve.FP_RATE_NAME).index();
        double[] tpRate = result.attributeToDoubleArray(tpIndex);
        double[] fpRate = result.attributeToDoubleArray(fpIndex);
        System.out.println("TPRate "+Arrays.toString(tpRate));
        System.out.println("FPRate "+Arrays.toString(fpRate));

        /*
         * 4.使用结果信息instances对象来显示ROC曲面
         */
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        // 这个获得AUC的方式与上面的不同,其实得到的都是一个共同的结果
        vmc.setROCString("(Area under ROC = " +
                Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        boolean [] cp=new boolean[result.numInstances()];
        for(int i=0;i
            cp[i]=true;
        tempd.setConnectPoints(cp);
        vmc.addPlot(tempd);
        // 显示曲面
        String plotName = vmc.getName();
        final javax.swing.JFrame jf =
                new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });
        jf.setVisible(true);

    }

你可能感兴趣的:(Weka)