Advanced Weka Usage: the Java API

Introduction

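These are my notes on the book 《数据挖掘与机器学习–WEKA应用技术与实践》 (Data Mining and Machine Learning: WEKA Application Techniques and Practice). An electronic copy of the book is available at: http://download.csdn.net/detail/fhb292262794/8759397
My previous post summarized how to demonstrate machine-learning algorithms with Weka, mainly through the GUI client of Weka 3.8.
This post uses the Java API instead, so that Weka's machine-learning algorithms can be applied to data from your own programs.

The book's examples use Weka 3.7. I downloaded the latest release, Weka 3.8, updated the code to work with it, and recorded the results below.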

1. Classification (step-by-step code)

1.1 Linear regression

Predicting house prices

House price data:

@RELATION house

@ATTRIBUTE houseSize NUMERIC
@ATTRIBUTE lotSize NUMERIC
@ATTRIBUTE bedrooms NUMERIC
@ATTRIBUTE granite NUMERIC
@ATTRIBUTE bathroom NUMERIC
@ATTRIBUTE sellingPrice NUMERIC

@DATA
3529,9191,6,0,0,205000 
3247,10061,5,1,1,224900 
4032,10150,5,0,1,197900 
2397,14156,4,1,0,189900 
2200,9600,4,0,1,195000 
3536,19994,6,1,1,325000 
2983,9365,5,0,1,230000 

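Requirement: given the features and selling prices of nearby houses in the area, predict the selling price of a new house with houseSize: 3198; lotSize: 9669; bedrooms: 5; granite: 3; bathroom: 1.

From the requirement, the processing steps are:
1. Load the house price data.
2. Set the attribute information (which attribute is the class).
3. Build the classifier and compute the regression coefficients.
4. Use the regression coefficients to predict the unknown price.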

The code is as follows:

    import org.junit.Test;
    import weka.classifiers.functions.LinearRegression;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils;

    // Dataset paths shared by the test methods in this post
    public static final String WEKA_PATH = "data/weka/";
    public static final String WEATHER_NOMINAL_PATH = "data/weka/weather.nominal.arff";
    public static final String WEATHER_NUMERIC_PATH = "data/weka/weather.numeric.arff";
    public static final String SEGMENT_CHALLENGE_PATH = "data/weka/segment-challenge.arff";
    public static final String SEGMENT_TEST_PATH = "data/weka/segment-test.arff";
    public static final String IONOSPHERE_PATH = "data/weka/ionosphere.arff";

    public static void pln(String str) {
        System.out.println(str);
    }

    @Test
    public void testLinearRegression() throws Exception {
        // Steps 1 and 2: load the data; the last attribute (sellingPrice) is the class
        Instances dataset = ConverterUtils.DataSource.read(WEKA_PATH + "houses.arff");
        dataset.setClassIndex(dataset.numAttributes() - 1);
        // Step 3: build the model (a failure propagates instead of leaving an untrained model)
        LinearRegression linearRegression = new LinearRegression();
        linearRegression.buildClassifier(dataset);
        // Step 4: coefficients() holds one entry per attribute (0 for the class slot)
        // followed by the intercept, so coef[5] is unused and coef[6] is the intercept
        double[] coef = linearRegression.coefficients();
        double myHouseValue = (coef[0] * 3198) +
                (coef[1] * 9669) +
                (coef[2] * 5) +
                (coef[3] * 3) +
                (coef[4] * 1) +
                coef[6];

        System.out.println(myHouseValue);
    }
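The hand-written formula relies on the layout of the array returned by coefficients(): as far as I can tell from Weka 3.8, it has one slot per attribute (the class slot, and any attribute the model dropped, stays 0) plus the intercept in the last slot, which is why coef[6] is the intercept for six attributes. The minimal plain-Java sketch below makes that layout explicit; the class CoefficientPredictor and its method are hypothetical helpers with no Weka dependency, and the numbers in main are made up.

```java
/**
 * Hypothetical helper (not part of Weka) showing how a coefficient array laid out
 * like LinearRegression.coefficients() turns into a prediction.
 */
final class CoefficientPredictor {

    /**
     * @param coef       one coefficient per attribute, then the intercept as the last entry
     * @param values     attribute values of the new instance, one per attribute
     * @param classIndex index of the class attribute; its value and coefficient are skipped
     */
    static double predict(double[] coef, double[] values, int classIndex) {
        if (coef.length != values.length + 1) {
            throw new IllegalArgumentException("coef needs exactly one more entry than values");
        }
        double sum = coef[coef.length - 1]; // start from the intercept
        for (int i = 0; i < values.length; i++) {
            if (i == classIndex) {
                continue; // the class attribute never contributes to its own prediction
            }
            sum += coef[i] * values[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        // Made-up coefficients: two attributes, a class slot (index 2) and an intercept
        double[] coef = {2.0, 3.0, 0.0, 10.0};
        double[] values = {1.0, 1.0, Double.NaN}; // class value unknown
        System.out.println(predict(coef, values, 2));
    }
}
```

For the house in this post, the call would be predict(coef, new double[]{3198, 9669, 5, 3, 1, Double.NaN}, 5) with the array from coefficients(). Inside Weka, putting the values into a DenseInstance, calling setDataset(dataset) on it and passing it to linearRegression.classifyInstance avoids hard-coding the formula at all.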

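1.2 Random forest

Code:

    import java.io.File;
    import org.junit.Test;
    import weka.classifiers.trees.RandomForest;
    import weka.core.Instances;
    import weka.core.converters.ArffLoader;

    @Test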
    public void testRandomForestClassifier() throws Exception {
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File(WEKA_PATH + "segment-challenge.arff"));
        Instances instances = loader.getDataSet();
        instances.setClassIndex(instances.numAttributes() - 1);
        System.out.println(instances);
        System.out.println("------------");

        RandomForest rf = new RandomForest();
        rf.buildClassifier(instances);
        System.out.println(rf);
    }

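1.3 Meta-classifier

    import java.util.Random;
    import org.junit.Test;
    import weka.attributeSelection.CfsSubsetEval;
    import weka.attributeSelection.GreedyStepwise;
    import weka.classifiers.Evaluation;
    import weka.classifiers.meta.AttributeSelectedClassifier;
    import weka.classifiers.trees.J48;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils;

    // Meta-classifier: CFS attribute selection with a backward greedy search, wrapped
    // around J48 and evaluated by 10-fold cross-validation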
    @Test
    public void testMetaClassifier() throws Exception {
        Instances data = ConverterUtils.DataSource.read(WEATHER_NUMERIC_PATH);
        if (data.classIndex() == -1)
            data.setClassIndex(data.numAttributes() - 1);

        AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
        CfsSubsetEval eval = new CfsSubsetEval();
        GreedyStepwise stepwise = new GreedyStepwise();
        stepwise.setSearchBackwards(true);
        J48 base = new J48();
        classifier.setClassifier(base);
        classifier.setEvaluator(eval);
        classifier.setSearch(stepwise);
        Evaluation evaluation = new Evaluation(data);
        evaluation.crossValidateModel(classifier, data, 10, new Random(1234));
        pln(evaluation.toSummaryString());
    }

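1.4 Predicting class labels (batch processing)

Code:

    import java.io.File;
    import org.junit.Test;
    import weka.classifiers.trees.J48;
    import weka.core.Instances;
    import weka.core.Utils;
    import weka.core.converters.ArffLoader;

    /**
     * Train on the training set and predict the classes of the whole test set in one batch.
     */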
    @Test
    public void testOutputClassDistribution() throws Exception {
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File(SEGMENT_CHALLENGE_PATH));
        Instances train = loader.getDataSet();
        train.setClassIndex(train.numAttributes() - 1);

        ArffLoader loader1 = new ArffLoader();
        loader1.setFile(new File(SEGMENT_TEST_PATH));
        Instances test = loader1.getDataSet();
        test.setClassIndex(test.numAttributes() - 1);

        J48 classifier = new J48();
        classifier.buildClassifier(train);
        System.out.println("num - actual - predicted - error - distribution");
        for (int i = 0; i < test.numInstances(); i++) {
            double pred = classifier.classifyInstance(test.instance(i));
            double[] dist = classifier.distributionForInstance(test.instance(i));
            StringBuilder sb = new StringBuilder();
            sb.append(i + 1)
                    .append(" - ")
                    .append(test.instance(i).toString(test.classIndex()))
                    .append(" - ")
                    .append(test.classAttribute().value((int) pred))
                    .append(" - ");
            if (pred != test.instance(i).classValue())
                sb.append("yes");
            else
                sb.append("no");
            sb.append(" - ");
            sb.append(Utils.arrayToString(dist));
            System.out.println(sb.toString());
        }
    }

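J48 (Weka's implementation of the C4.5 decision tree) is used here. It can be replaced by other, possibly better classifiers, so compare their results before choosing one.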
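To compare classifiers this way, the per-instance output of the loop can be reduced to one accuracy figure per classifier. The minimal sketch below is plain Java with hypothetical names (PredictionSummary, argMax, accuracy) and no Weka dependency. argMax mirrors the predicted column: to my understanding, AbstractClassifier.classifyInstance returns the index of the largest value of distributionForInstance for a nominal class. accuracy counts the rows whose error column would read "no"; in practice the int arrays would be filled from classValue() and classifyInstance.

```java
/** Hypothetical helpers (not part of Weka) for summarizing batch predictions. */
final class PredictionSummary {

    /** Index of the largest probability, i.e. the most probable class; ties go to the lowest index. */
    static int argMax(double[] dist) {
        int best = 0;
        for (int i = 1; i < dist.length; i++) {
            if (dist[i] > dist[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of instances whose predicted class index equals the actual one. */
    static double accuracy(int[] actual, int[] predicted) {
        if (actual.length != predicted.length || actual.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] == predicted[i]) {
                correct++;
            }
        }
        return (double) correct / actual.length;
    }

    public static void main(String[] args) {
        int[] actual = {0, 1, 2, 1};
        // Predictions of two hypothetical classifiers for the same four test instances
        int[] predA = {0, 1, 2, 2};
        int[] predB = {0, 1, 1, 0};
        System.out.println("A accuracy: " + accuracy(actual, predA)); // prints "A accuracy: 0.75"
        System.out.println("B accuracy: " + accuracy(actual, predB)); // prints "B accuracy: 0.5"
    }
}
```

Accuracy on a single test set is a coarse yardstick; the cross-validation in the next section gives a more stable comparison.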

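1.5 Cross-validation

Code:

    import org.junit.Test;
    import weka.classifiers.AbstractClassifier;
    import weka.classifiers.Classifier;
    import weka.classifiers.Evaluation;
    import weka.core.Debug;
    import weka.core.Instances;
    import weka.core.OptionHandler;
    import weka.core.Utils;
    import weka.core.converters.ConverterUtils;
    import weka.filters.Filter;
    import weka.filters.supervised.attribute.AddClassification;

    // Cross-validate and collect the predictions of every fold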
    @Test
    public void testOnceCVAndPrediction() throws Exception {
        Instances data = ConverterUtils.DataSource.read(IONOSPHERE_PATH);
        data.setClassIndex(data.numAttributes() - 1);
        Classifier classifier = new J48();
        int seed = 1234;
        int folds = 10;

        Debug.Random random = new Debug.Random(seed);
        Instances newData = new Instances(data);
        newData.randomize(random);
        if (newData.classAttribute().isNominal())
            newData.stratify(folds);

        // Run cross-validation and add the predictions
        Instances predictedData = null;
        Evaluation eval = new Evaluation(newData);
        for (int i = 0; i < folds; i++) {
            Instances train = newData.trainCV(folds, i);
            Instances test = newData.testCV(folds, i);
            Classifier clsCopy = AbstractClassifier.makeCopy(classifier);
            clsCopy.buildClassifier(train);
            eval.evaluateModel(clsCopy, test);

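            // Add predictions: the first useFilter call (on train) builds the filter's own
            // copy of the classifier, the second one labels the test fold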
            AddClassification filter = new AddClassification();
            filter.setClassifier(classifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(true);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter);
            Instances pred = Filter.useFilter(test, filter);
            if (predictedData == null)
                predictedData = new Instances(pred, 0);
            for (int j = 0; j < pred.numInstances(); j++)
                predictedData.add(pred.instance(j));
        }
        pln("classifier:" + classifier.getClass().getName() + " " + Utils.joinOptions(((OptionHandler) classifier).getOptions()));
        pln("data:" + data.relationName());
        pln("seed:" + seed);
        pln(eval.toSummaryString("=== " + folds + "-fold cross-validation ===", false));
        // Write the instances with the added prediction columns to an ARFF file
        ConverterUtils.DataSink.write(WEKA_PATH + "predictions.arff", predictedData);
    }
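Calling newData.stratify(folds) before trainCV/testCV keeps the class proportions of each fold close to those of the whole dataset. The plain-Java sketch below illustrates the idea with hypothetical names: sort the instances by class label and deal them to the folds round-robin. It is a simplification, not Weka's code; as far as I can tell, Weka reorders the instances in a similar spirit and then cuts contiguous blocks, so the exact fold membership will differ.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Simplified, hypothetical illustration of stratified k-fold assignment:
 * order instances by class, then deal them to folds round-robin so that
 * every fold gets roughly the class proportions of the whole dataset.
 * This is not Weka's exact stratify/trainCV/testCV code.
 */
final class StratifiedFolds {

    /** Returns fold[i] in [0, k) for instance i, given the class label of every instance. */
    static int[] assignFolds(int[] labels, int k) {
        if (k < 2 || k > labels.length) {
            throw new IllegalArgumentException("need 2 <= k <= number of instances");
        }
        Integer[] order = new Integer[labels.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on objects is stable, so instances of one class keep their input order
        Arrays.sort(order, (a, b) -> Integer.compare(labels[a], labels[b]));
        int[] fold = new int[labels.length];
        for (int pos = 0; pos < order.length; pos++) {
            fold[order[pos]] = pos % k; // deal like cards
        }
        return fold;
    }

    /** Indices of the instances forming the test split of fold f; all others are training data. */
    static List<Integer> testIndices(int[] fold, int f) {
        List<Integer> out = new ArrayList<>();
        for (int i = 0; i < fold.length; i++) {
            if (fold[i] == f) {
                out.add(i);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[] labels = {1, 0, 1, 0, 1, 0, 1, 0};
        int[] fold = assignFolds(labels, 2);
        System.out.println("test indices of fold 0: " + testIndices(fold, 0));
    }
}
```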

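2. Clustering (step-by-step code)

2.1 EM

    import org.junit.Test;
    import weka.clusterers.EM;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils;

    @Test
    public void testEM() throws Exception {
        Instances instances = ConverterUtils.DataSource.read(WEKA_PATH + "contact-lenses.arff");
        EM cluster = new EM();
        // -I 100: at most 100 EM iterations
        cluster.setOptions(new String[]{"-I", "100"});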
        cluster.buildClusterer(instances);
        pln(cluster.toString());
    }

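2.2 Evaluating a clusterer

    import org.junit.Test;
    import weka.clusterers.ClusterEvaluation;
    import weka.clusterers.DensityBasedClusterer;
    import weka.clusterers.EM;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils;

    // Three ways to evaluate a clusterer
    @Test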
    public void testEvaluation() throws Exception {
        String filePath = WEKA_PATH + "contact-lenses.arff";
        Instances instances = ConverterUtils.DataSource.read(filePath);
        // Method 1: command-line style options (-t gives the training file)
        String[] options = new String[]{"-t", filePath};
        String output = ClusterEvaluation.evaluateClusterer(new EM(), options);
        pln(output);

        // Method 2: build the clusterer first, then evaluate it on the data
        DensityBasedClusterer dbc = new EM();
        dbc.buildClusterer(instances);
        ClusterEvaluation clusterEvaluation = new ClusterEvaluation();
        clusterEvaluation.setClusterer(dbc);
        clusterEvaluation.evaluateClusterer(new Instances(instances));
        pln(clusterEvaluation.clusterResultsToString());

        // Method 3: cross-validation, available for density-based clusterers;
        // it returns the cross-validated log-likelihood
        DensityBasedClusterer newdbc = new EM();
        double logLikelihood = ClusterEvaluation.crossValidateModel(newdbc, instances, 10, instances.getRandomNumberGenerator(1234));
        pln("logLikelihood: " + logLikelihood);
    }

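2.3 Clustering with classes-to-clusters evaluation

    import org.junit.Test;
    import weka.clusterers.ClusterEvaluation;
    import weka.clusterers.Clusterer;
    import weka.clusterers.EM;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils;
    import weka.filters.Filter;
    import weka.filters.unsupervised.attribute.Remove;

    @Test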
    public void testClassesToClusters() throws Exception {
        String filePath = WEKA_PATH + "contact-lenses.arff";
        Instances data = ConverterUtils.DataSource.read(filePath);
        data.setClassIndex(data.numAttributes() - 1);
        // Remove the class attribute (Remove takes a 1-based index) so the clusterer never sees it
        Remove remove = new Remove();
        remove.setAttributeIndices("" + (data.classIndex() + 1));
        remove.setInputFormat(data);
        Instances dataCluster = Filter.useFilter(data, remove);

        Clusterer cluster = new EM();
        cluster.buildClusterer(dataCluster);

        ClusterEvaluation eval = new ClusterEvaluation();
        eval.setClusterer(cluster);
        // Evaluating on data that still has its class index set gives the classes-to-clusters report
        eval.evaluateClusterer(data);

        pln(eval.clusterResultsToString());
    }
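In the classes-to-clusters report, each cluster is labeled with a class, and instances outside their cluster's class are counted as incorrectly clustered. As far as I understand ClusterEvaluation, each class can be given to at most one cluster and a cluster may also get no class, chosen so that the errors are minimal. The plain-Java sketch below reproduces that kind of mapping by brute force from a cluster-by-class count matrix. The names are hypothetical and the counts in main are made up; it is meant to explain the report, not to replace Weka's implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of a classes-to-clusters mapping: every cluster gets a
 * distinct class (or no class, -1) so that the number of correctly clustered
 * instances is maximal. Brute force, fine for a handful of clusters.
 */
final class ClassesToClusters {

    /** counts[c][k] = number of instances of class k that fall into cluster c. */
    static int[] bestMapping(int[][] counts) {
        int numClasses = counts.length == 0 ? 0 : counts[0].length;
        int[] current = new int[counts.length];
        int[] best = new int[counts.length];
        int[] bestCorrect = {-1};
        search(counts, 0, new boolean[numClasses], current, 0, best, bestCorrect);
        return best; // best[c] = class index for cluster c, or -1 for "no class"
    }

    private static void search(int[][] counts, int cluster, boolean[] used, int[] current,
                               int correct, int[] best, int[] bestCorrect) {
        if (cluster == counts.length) {
            if (correct > bestCorrect[0]) {
                bestCorrect[0] = correct;
                System.arraycopy(current, 0, best, 0, current.length);
            }
            return;
        }
        // Option 1: leave this cluster without a class
        current[cluster] = -1;
        search(counts, cluster + 1, used, current, correct, best, bestCorrect);
        // Option 2: give it any class not yet taken by another cluster
        for (int k = 0; k < used.length; k++) {
            if (!used[k]) {
                used[k] = true;
                current[cluster] = k;
                search(counts, cluster + 1, used, current, correct + counts[cluster][k], best, bestCorrect);
                used[k] = false;
            }
        }
    }

    /** Number of instances that do not belong to the class assigned to their cluster. */
    static int errors(int[][] counts, int[] mapping) {
        int total = 0;
        int correct = 0;
        for (int c = 0; c < counts.length; c++) {
            for (int k = 0; k < counts[c].length; k++) {
                total += counts[c][k];
            }
            if (mapping[c] >= 0) {
                correct += counts[c][mapping[c]];
            }
        }
        return total - correct;
    }

    public static void main(String[] args) {
        // Rows = clusters, columns = classes (made-up counts)
        int[][] counts = {{5, 3}, {4, 0}};
        int[] mapping = bestMapping(counts);
        System.out.println(Arrays.toString(mapping) + ", errors = " + errors(counts, mapping));
    }
}
```

With these counts, labeling both clusters with their majority class is not allowed (both would take class 0), so the sketch assigns cluster 0 to class 1 and cluster 1 to class 0.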

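2.4 Outputting the cluster of each instance

    import org.junit.Test;
    import weka.clusterers.EM;
    import weka.core.Instances;
    import weka.core.Utils;
    import weka.core.converters.ConverterUtils;

    @Test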
    public void testOutputClusterDistribution() throws Exception {
        Instances train = ConverterUtils.DataSource.read(SEGMENT_CHALLENGE_PATH);
        Instances test = ConverterUtils.DataSource.read(SEGMENT_TEST_PATH);
        if (!train.equalHeaders(test))
            throw new Exception("Train and test data do not have the same headers.");

        EM clusterer = new EM();
        clusterer.buildClusterer(train);
        pln("id - cluster - distribution");
        for (int i = 0; i < test.numInstances(); i++) {
            int cluster = clusterer.clusterInstance(test.instance(i));
            double[] dists = clusterer.distributionForInstance(test.instance(i));
            StringBuilder sb = new StringBuilder();
            sb.append(i + 1).append(" - ").append(cluster).append(" - ").append(Utils.arrayToString(dists));
            pln(sb.toString());
        }
    }

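3. Attribute selection (step-by-step code)

Automatic attribute selection

Using CfsSubsetEval with a backward GreedyStepwise search:

    import org.junit.Test;
    import weka.attributeSelection.AttributeSelection;
    import weka.attributeSelection.CfsSubsetEval;
    import weka.attributeSelection.GreedyStepwise;
    import weka.core.Instances;
    import weka.core.Utils;
    import weka.core.converters.ConverterUtils;

    // Attribute selection with the low-level API
    @Test
    public void testUseLowApi() throws Exception {
        ConverterUtils.DataSource source = new ConverterUtils.DataSource(WEATHER_NOMINAL_PATH);
        Instances data = source.getDataSet();
        if (data.classIndex() == -1)
            data.setClassIndex(data.numAttributes() - 1);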
        AttributeSelection attributeSelection = new AttributeSelection();
        CfsSubsetEval eval = new CfsSubsetEval();
        GreedyStepwise search = new GreedyStepwise();
        search.setSearchBackwards(true);
        attributeSelection.setEvaluator(eval);
        attributeSelection.setSearch(search);
        attributeSelection.SelectAttributes(data);
        // 0-based indices of the selected attributes; the class index is appended last
        int[] indices = attributeSelection.selectedAttributes();
        pln(Utils.arrayToString(indices));
    }
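CfsSubsetEval scores a subset of k attributes by the CFS merit: k times the mean attribute-class correlation, divided by the square root of k + k(k-1) times the mean attribute-attribute correlation. It therefore prefers attributes that predict the class but are not redundant with each other. With setSearchBackwards(true), GreedyStepwise starts from all attributes and keeps removing the one whose removal raises the merit most, stopping when no removal helps. The sketch below reproduces both steps in plain Java from given correlation values. The class, its methods and the correlation numbers are hypothetical; Weka computes the correlations itself from the data (for nominal attributes with an information-theoretic measure, as I understand it), so this is an illustration rather than its implementation.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch (not Weka code) of the CFS merit plus a backward greedy search.
 * rcf[i]    = correlation of attribute i with the class
 * rff[i][j] = correlation between attributes i and j (symmetric)
 */
final class CfsBackwardSketch {

    /** sum(rcf) / sqrt(k + 2 * sum over pairs of rff), equal to k*avg(rcf) / sqrt(k + k(k-1)*avg(rff)). */
    static double merit(List<Integer> subset, double[] rcf, double[][] rff) {
        int k = subset.size();
        if (k == 0) {
            return 0.0;
        }
        double num = 0.0;
        double pairs = 0.0;
        for (int a = 0; a < k; a++) {
            num += rcf[subset.get(a)];
            for (int b = a + 1; b < k; b++) {
                pairs += rff[subset.get(a)][subset.get(b)];
            }
        }
        return num / Math.sqrt(k + 2.0 * pairs);
    }

    /** Start from all attributes; drop one at a time while the merit strictly improves. */
    static List<Integer> searchBackwards(double[] rcf, double[][] rff) {
        List<Integer> current = new ArrayList<>();
        for (int i = 0; i < rcf.length; i++) {
            current.add(i);
        }
        double currentMerit = merit(current, rcf, rff);
        while (current.size() > 1) {
            List<Integer> bestCandidate = null;
            double bestMerit = currentMerit;
            for (int drop = 0; drop < current.size(); drop++) {
                List<Integer> candidate = new ArrayList<>(current);
                candidate.remove(drop); // int argument: removes by position
                double m = merit(candidate, rcf, rff);
                if (m > bestMerit) {
                    bestMerit = m;
                    bestCandidate = candidate;
                }
            }
            if (bestCandidate == null) {
                break; // no single removal improves the merit
            }
            current = bestCandidate;
            currentMerit = bestMerit;
        }
        return current;
    }

    public static void main(String[] args) {
        // Made-up correlations: attributes 0 and 1 are redundant, attribute 2 is weak
        double[] rcf = {0.8, 0.7, 0.1};
        double[][] rff = {{1.0, 0.9, 0.0}, {0.9, 1.0, 0.0}, {0.0, 0.0, 1.0}};
        System.out.println("selected: " + searchBackwards(rcf, rff));
    }
}
```

In this made-up example the weak attribute 2 goes first, then attribute 1 because it largely duplicates attribute 0, leaving only attribute 0.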

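4. Miscellaneous

4.1 Reading a database table and saving it as CSV

    import java.io.File;
    import org.junit.Test;
    import weka.core.Instances;
    import weka.core.converters.CSVSaver;
    import weka.core.converters.DatabaseLoader;

    @Test
    public void testSaveCSV() throws Exception {
        DatabaseLoader loader = new DatabaseLoader();
        // SqlUtil is the author's own class holding the JDBC URL, user and password (not shown);
        // the matching JDBC driver must be on the classpath
        loader.setUrl(SqlUtil.URL);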
        loader.setUser(SqlUtil.USER);
        loader.setPassword(SqlUtil.PASSWORD);
        loader.setQuery("select question from question");
        Instances data1 = loader.getDataSet();
        if (data1.classIndex() == -1)
            data1.setClassIndex(data1.numAttributes() - 1);
        System.out.println(data1);

        CSVSaver saver = new CSVSaver();
        saver.setInstances(data1);
        saver.setFile(new File("data/weka/baidubook-csvsaver.csv"));
        saver.writeBatch();

    }

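4.2 Filters

Filtering

    import org.junit.Test;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils;
    import weka.filters.Filter;
    import weka.filters.unsupervised.attribute.Remove;

    @Test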
    public void testFilter() throws Exception {
        Instances instances = ConverterUtils.DataSource.read("data/weka/houses.arff");
        instances.setClassIndex(instances.numAttributes() - 1);
        System.out.println(instances);
        // -R 1: remove the first attribute (indices are 1-based)
        String[] options = new String[2];
        options[0] = "-R";
        options[1] = "1";
        Remove remove = new Remove();
        remove.setOptions(options);
        remove.setInputFormat(instances);
        Instances newData = Filter.useFilter(instances,remove);
        System.out.println(newData);
    }
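The option pair -R 1 (equivalently setAttributeIndices("1")) uses a 1-based range string, so "1" removes houseSize, the first attribute. The hypothetical helper below shows in plain Java how such strings map to 0-based attribute indices. It covers only comma-separated numbers, ranges, first and last, which is a subset of what weka.core.Range accepts, and it has no inversion.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical parser (not Weka code) for a subset of Weka-style attribute ranges
 * such as "1", "1-3,5", "first", "last" or "first-last". Indices in the string
 * are 1-based; the result is 0-based, as used by Instances.attribute(int).
 */
final class RangeSketch {

    static List<Integer> parse(String spec, int numAttributes) {
        List<Integer> out = new ArrayList<>();
        for (String part : spec.split(",")) {
            String p = part.trim();
            int dash = p.indexOf('-');
            int from = toIndex(dash < 0 ? p : p.substring(0, dash), numAttributes);
            int to = dash < 0 ? from : toIndex(p.substring(dash + 1), numAttributes);
            for (int i = from; i <= to; i++) {
                if (!out.contains(i)) { // skip duplicates from overlapping parts
                    out.add(i);
                }
            }
        }
        return out;
    }

    /** Converts "first", "last" or a 1-based number into a 0-based index. */
    private static int toIndex(String token, int numAttributes) {
        String t = token.trim();
        int oneBased = t.equals("first") ? 1
                : t.equals("last") ? numAttributes
                : Integer.parseInt(t);
        if (oneBased < 1 || oneBased > numAttributes) {
            throw new IllegalArgumentException("index out of range: " + token);
        }
        return oneBased - 1;
    }

    public static void main(String[] args) {
        // houses.arff has 6 attributes; "1" selects houseSize, stored at index 0
        System.out.println(parse("1", 6));
        System.out.println(parse("1-3,5", 6));
    }
}
```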

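Filtering and classifying on the fly

    import org.junit.Test;
    import weka.classifiers.meta.FilteredClassifier;
    import weka.classifiers.trees.J48;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils;
    import weka.filters.Filter;
    import weka.filters.unsupervised.attribute.Remove;

    @Test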
    public void testFilterOnTheFly() throws Exception {
        Instances instances = ConverterUtils.DataSource.read("data/weka/weather.nominal.arff");
        instances.setClassIndex(instances.numAttributes() - 1);
        System.out.println(instances);
        Remove remove = new Remove();
        remove.setAttributeIndices("1");
        // classify
        J48 j48 = new J48();
        j48.setUnpruned(true);
        FilteredClassifier fc = new FilteredClassifier();
        fc.setFilter(remove);
        fc.setClassifier(j48);
        fc.buildClassifier(instances);
        System.out.println(fc);
        for (int i = 0; i < instances.numInstances(); i++) {
            double pred = fc.classifyInstance(instances.instance(i));
            // Print "actual -> predicted" for each instance
            System.out.print(instances.classAttribute().value((int) instances.instance(i).classValue()));
            System.out.println(" -> " + instances.classAttribute().value((int) pred));
        }

        remove.setInputFormat(instances);
        Instances newData = Filter.useFilter(instances,remove);
        System.out.println(newData);
    }
