本文是我学习《数据挖掘与机器学习——WEKA应用技术与实践》的笔记。该书电子版的链接是:http://download.csdn.net/detail/fhb292262794/8759397
前一篇博文总结了如何用Weka演示机器学习算法,主要是通过Weka 3.8的图形界面客户端操作。
本文通过Java API调用来处理,这样就可以在编程中应用Weka的机器学习算法处理数据。
本书的示例基于Weka 3.7,而我下载使用的是最新版Weka 3.8;将代码更新为适配Weka 3.8之后,整理记录如下。
预测房价
房价数据(houses.arff):
@RELATION house
@ATTRIBUTE houseSize NUMERIC
@ATTRIBUTE lotSize NUMERIC
@ATTRIBUTE bedrooms NUMERIC
@ATTRIBUTE granite NUMERIC
@ATTRIBUTE bathroom NUMERIC
@ATTRIBUTE sellingPrice NUMERIC
@DATA
3529,9191,6,0,0,205000
3247,10061,5,1,1,224900
4032,10150,5,0,1,197900
2397,14156,4,1,0,189900
2200,9600,4,0,1,195000
3536,19994,6,1,1,325000
2983,9365,5,0,1,230000
需求:根据该地区附近的房子信息及房价出售价格,预测新的房子的售价。该房子信息是:houseSize:3198;lotSize:9669;bedrooms:5;granite:3;bathroom:1;请预测房价。
由需求可知,需要处理的逻辑是:
1.加载房价数据。
2.设置属性信息。
3.构建线性回归分类器并计算回归系数。
4.使用回归系数预测未知房价。
代码如下:
/** Directory that holds every Weka sample dataset used by these tests. */
public static final String WEKA_PATH = "data/weka/";
/** Weather dataset with all-nominal attributes. */
public static final String WEATHER_NOMINAL_PATH = WEKA_PATH + "weather.nominal.arff";
/** Weather dataset with numeric temperature/humidity attributes. */
public static final String WEATHER_NUMERIC_PATH = WEKA_PATH + "weather.numeric.arff";
/** Image-segmentation training set. */
public static final String SEGMENT_CHALLENGE_PATH = WEKA_PATH + "segment-challenge.arff";
/** Image-segmentation test set; expected to share the training set's header. */
public static final String SEGMENT_TEST_PATH = WEKA_PATH + "segment-test.arff";
/** Ionosphere radar dataset. */
public static final String IONOSPHERE_PATH = WEKA_PATH + "ionosphere.arff";
/**
 * Prints a line of text to standard output.
 *
 * @param str the text to print; {@code null} is printed as {@code "null"}
 */
public static void pln(String str) {
    String line = String.valueOf(str);
    System.out.println(line);
}
/**
 * Fits a linear regression on houses.arff, using the last attribute (sellingPrice)
 * as the class. Then predicts the selling price of one new house from the learned
 * coefficients and prints it.
 *
 * @throws Exception if the dataset cannot be read or the regression cannot be built
 */
@Test
public void testLinearRegression() throws Exception {
    Instances dataset = ConverterUtils.DataSource.read(WEKA_PATH + "houses.arff");
    dataset.setClassIndex(dataset.numAttributes() - 1);

    LinearRegression linearRegression = new LinearRegression();
    // Let a build failure propagate: continuing with an untrained model would make
    // the coefficient lookup below fail or print a meaningless price.
    linearRegression.buildClassifier(dataset);

    // Non-class attribute values of the house to price, in file order:
    // houseSize, lotSize, bedrooms, granite, bathroom.
    double[] myHouse = {3198, 9669, 5, 3, 1};
    if (myHouse.length != dataset.numAttributes() - 1) {
        throw new IllegalStateException("expected " + (dataset.numAttributes() - 1)
                + " feature values but got " + myHouse.length);
    }

    // Indexing coefficients by attribute position assumes every attribute is numeric,
    // as in houses.arff — NOTE(review): revisit if nominal attributes are added.
    double[] coef = linearRegression.coefficients();
    double myHouseValue = 0;
    int feature = 0;
    for (int i = 0; i < dataset.numAttributes(); i++) {
        if (i == dataset.classIndex()) {
            continue; // skip the class attribute's slot
        }
        myHouseValue += coef[i] * myHouse[feature++];
    }
    // The intercept is stored after the per-attribute slots.
    myHouseValue += coef[dataset.numAttributes()];
    System.out.println(myHouseValue);
}
随机森林分类,代码:
/**
 * Trains a random forest on segment-challenge.arff, using the last attribute as the
 * class. Prints the loaded data, a separator line, and the trained model.
 */
@Test
public void testRandomForestClassifier() throws Exception {
    ArffLoader arffLoader = new ArffLoader();
    arffLoader.setFile(new File(WEKA_PATH + "segment-challenge.arff"));
    Instances segments = arffLoader.getDataSet();
    int classIdx = segments.numAttributes() - 1;
    segments.setClassIndex(classIdx);
    System.out.println(segments);
    System.out.println("------------");

    RandomForest forest = new RandomForest();
    forest.buildClassifier(segments);
    System.out.println(forest);
}
// Meta-classifier example.
/**
 * Cross-validates an AttributeSelectedClassifier with 10 folds and seed 1234. The
 * classifier wraps J48 and picks attributes with CfsSubsetEval plus a backward
 * GreedyStepwise search. Prints the evaluation summary.
 */
@Test
public void testMetaClassifier() throws Exception {
    Instances data = ConverterUtils.DataSource.read(WEATHER_NUMERIC_PATH);
    boolean classUnset = data.classIndex() == -1;
    if (classUnset) {
        data.setClassIndex(data.numAttributes() - 1);
    }

    GreedyStepwise backwardSearch = new GreedyStepwise();
    backwardSearch.setSearchBackwards(true);

    AttributeSelectedClassifier metaClassifier = new AttributeSelectedClassifier();
    metaClassifier.setClassifier(new J48());
    metaClassifier.setEvaluator(new CfsSubsetEval());
    metaClassifier.setSearch(backwardSearch);

    Evaluation evaluation = new Evaluation(data);
    evaluation.crossValidateModel(metaClassifier, data, 10, new Random(1234));
    pln(evaluation.toSummaryString());
}
用训练集训练分类器,再批量预测测试集的类别及其概率分布,代码:
/**
 * Trains J48 on the training set and predicts the test set's classes in one batch.
 * Training data is segment-challenge.arff; test data is segment-test.arff.
 * For each test instance it prints:
 * the 1-based number, the actual class, the predicted class,
 * an error flag ("yes" when prediction and actual class differ), and the class distribution.
 *
 * @throws Exception if a dataset cannot be read or the classifier cannot be built
 * @throws IllegalStateException if the training and test headers differ
 */
@Test
public void testOutputClassDistribution() throws Exception {
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setFile(new File(SEGMENT_CHALLENGE_PATH));
    Instances train = trainLoader.getDataSet();
    train.setClassIndex(train.numAttributes() - 1);

    ArffLoader testLoader = new ArffLoader();
    testLoader.setFile(new File(SEGMENT_TEST_PATH));
    Instances test = testLoader.getDataSet();
    test.setClassIndex(test.numAttributes() - 1);

    // Class indices and labels are only comparable when both sets share one header.
    if (!train.equalHeaders(test)) {
        throw new IllegalStateException("train data and test data not the same headers.");
    }

    J48 classifier = new J48();
    classifier.buildClassifier(train);

    // Header uses the same " - " separator as the rows so the columns line up.
    System.out.println("num - fact - pred - err - distribution");
    for (int i = 0; i < test.numInstances(); i++) {
        double pred = classifier.classifyInstance(test.instance(i));
        double[] dist = classifier.distributionForInstance(test.instance(i));
        StringBuilder sb = new StringBuilder();
        sb.append(i + 1)
                .append(" - ")
                .append(test.instance(i).toString(test.classIndex()))
                .append(" - ")
                .append(test.classAttribute().value((int) pred))
                .append(" - ")
                .append(pred != test.instance(i).classValue() ? "yes" : "no")
                .append(" - ")
                .append(Utils.arrayToString(dist));
        System.out.println(sb.toString());
    }
}
这里使用的是J48决策树分类器,也可以替换成其他分类器;建议比较各分类器的效果后再选用。
交叉验证并输出预测结果,代码:
// Cross-validation plus per-instance predictions.
/**
 * Runs one 10-fold cross-validation of J48 on ionosphere.arff with seed 1234.
 * Prints the classifier setup and the evaluation summary.
 * Writes every test instance, extended by AddClassification with prediction,
 * distribution and error-flag output, to predictions.arff under WEKA_PATH.
 */
@Test
public void testOnceCVAndPrediction() throws Exception {
    Instances data = ConverterUtils.DataSource.read(IONOSPHERE_PATH);
    data.setClassIndex(data.numAttributes() - 1);

    final Classifier template = new J48();
    final int seed = 1234;
    final int folds = 10;

    // Shuffle a copy so `data` keeps its order; stratify only when the class is nominal.
    Instances shuffled = new Instances(data);
    shuffled.randomize(new Debug.Random(seed));
    if (shuffled.classAttribute().isNominal()) {
        shuffled.stratify(folds);
    }

    Evaluation eval = new Evaluation(shuffled);
    Instances predictedData = null;
    for (int fold = 0; fold < folds; fold++) {
        Instances train = shuffled.trainCV(folds, fold);
        Instances test = shuffled.testCV(folds, fold);

        Classifier foldModel = AbstractClassifier.makeCopy(template);
        foldModel.buildClassifier(train);
        eval.evaluateModel(foldModel, test);

        Instances labelled = predictFold(template, train, test);
        if (predictedData == null) {
            // Empty dataset with the extended header produced by the filter.
            predictedData = new Instances(labelled, 0);
        }
        for (int row = 0; row < labelled.numInstances(); row++) {
            predictedData.add(labelled.instance(row));
        }
    }

    String optionText = Utils.joinOptions(((OptionHandler) template).getOptions());
    pln("classifier:" + template.getClass().getName() + " " + optionText);
    pln("data:" + data.relationName());
    pln("seed:" + seed);
    pln(eval.toSummaryString("=== " + folds + " test ===", false));

    ConverterUtils.DataSink.write(WEKA_PATH + "predictions.arff", predictedData);
}

/**
 * Configures an AddClassification filter around {@code template}.
 * Pushes {@code train} through it first, then {@code test}, and returns the
 * filtered test data with classification, distribution and error-flag output enabled.
 */
private static Instances predictFold(Classifier template, Instances train, Instances test)
        throws Exception {
    AddClassification filter = new AddClassification();
    filter.setClassifier(template);
    filter.setOutputClassification(true);
    filter.setOutputDistribution(true);
    filter.setOutputErrorFlag(true);
    filter.setInputFormat(train);
    // First batch: presumably trains the filter's classifier on `train`; its output is
    // not needed (see AddClassification docs). Must run before the test batch.
    Filter.useFilter(train, filter);
    return Filter.useFilter(test, filter);
}
/**
 * Builds an EM clusterer on contact-lenses.arff with options {@code -I 100} and
 * prints the resulting model.
 */
@Test
public void testEM() throws Exception {
    Instances lenses = ConverterUtils.DataSource.read(WEKA_PATH + "contact-lenses.arff");
    String[] emOptions = {"-I", "100"};
    EM em = new EM();
    em.setOptions(emOptions);
    em.buildClusterer(lenses);
    pln(em.toString());
}
// Three ways to evaluate a clusterer.
/**
 * Shows three ways of evaluating an EM clusterer on contact-lenses.arff:
 * (1) the static, option-driven evaluateClusterer with {@code -t <file>};
 * (2) building a model, then evaluating it on a copy of the same data;
 * (3) 10-fold cross-validation of a density-based clusterer, printing the log-likelihood.
 */
@Test
public void testEvaluation() throws Exception {
    final String arff = WEKA_PATH + "contact-lenses.arff";
    Instances lenses = ConverterUtils.DataSource.read(arff);

    // Way 1: evaluation driven entirely by command-line style options.
    pln(ClusterEvaluation.evaluateClusterer(new EM(), new String[]{"-t", arff}));

    // Way 2: explicit model building followed by a ClusterEvaluation.
    DensityBasedClusterer trained = new EM();
    trained.buildClusterer(lenses);
    ClusterEvaluation evaluation = new ClusterEvaluation();
    evaluation.setClusterer(trained);
    evaluation.evaluateClusterer(new Instances(lenses));
    pln(evaluation.clusterResultsToString());

    // Way 3: cross-validation of a density-based clusterer.
    DensityBasedClusterer untrained = new EM();
    double logLikelihood = ClusterEvaluation.crossValidateModel(
            untrained, lenses, 10, lenses.getRandomNumberGenerator(1234));
    pln("logLikelyhood: " + logLikelihood);
}
/**
 * Runs a classes-to-clusters evaluation on contact-lenses.arff.
 * EM is trained on the data with the class attribute (the last one) removed.
 * The clusterer is then evaluated against the full data with the class index set,
 * and the results are printed.
 */
@Test
public void testClassesToClusters() throws Exception {
    Instances data = ConverterUtils.DataSource.read(WEKA_PATH + "contact-lenses.arff");
    int classIdx = data.numAttributes() - 1;
    data.setClassIndex(classIdx);

    // Remove takes 1-based attribute indices, hence the +1.
    Remove dropClass = new Remove();
    dropClass.setAttributeIndices(String.valueOf(classIdx + 1));
    dropClass.setInputFormat(data);
    Instances withoutClass = Filter.useFilter(data, dropClass);

    Clusterer em = new EM();
    em.buildClusterer(withoutClass);

    ClusterEvaluation eval = new ClusterEvaluation();
    eval.setClusterer(em);
    eval.evaluateClusterer(data);
    pln(eval.clusterResultsToString());
}
/**
 * Builds EM on segment-challenge.arff. For every instance of segment-test.arff it
 * prints the 1-based id, the assigned cluster, and the cluster membership
 * distribution. Fails fast if the two files have different headers.
 */
@Test
public void testOutputClusterDistribution() throws Exception {
    Instances train = ConverterUtils.DataSource.read(SEGMENT_CHALLENGE_PATH);
    Instances test = ConverterUtils.DataSource.read(SEGMENT_TEST_PATH);
    boolean compatible = train.equalHeaders(test);
    if (!compatible) {
        throw new Exception("train data and test data not the same headers.");
    }

    EM em = new EM();
    em.buildClusterer(train);

    pln("id - cluster - distribution");
    int total = test.numInstances();
    for (int idx = 0; idx < total; idx++) {
        int assigned = em.clusterInstance(test.instance(idx));
        double[] membership = em.distributionForInstance(test.instance(idx));
        pln((idx + 1) + " - " + assigned + " - " + Utils.arrayToString(membership));
    }
}
使用CfsSubsetEval(属性子集评估器)和GreedyStepwise(贪心逐步搜索)进行属性选择:
// Attribute selection through the low-level AttributeSelection API.
/**
 * Runs CfsSubsetEval with a backward GreedyStepwise search on weather.nominal.arff
 * and prints the indices returned by selectedAttributes().
 */
@Test
public void testUseLowApi() throws Exception {
    Instances data = new ConverterUtils.DataSource(WEATHER_NOMINAL_PATH).getDataSet();
    if (data.classIndex() == -1) {
        data.setClassIndex(data.numAttributes() - 1);
    }

    GreedyStepwise backward = new GreedyStepwise();
    backward.setSearchBackwards(true);

    AttributeSelection selector = new AttributeSelection();
    selector.setEvaluator(new CfsSubsetEval());
    selector.setSearch(backward);
    selector.SelectAttributes(data);

    pln(Utils.arrayToString(selector.selectedAttributes()));
}
/**
 * Loads the result of {@code select question from question} through Weka's
 * DatabaseLoader, using connection settings from SqlUtil.
 * Prints the data and saves it as data/weka/baidubook-csvsaver.csv.
 */
@Test
public void testSaveCSV() throws Exception {
    DatabaseLoader dbLoader = new DatabaseLoader();
    dbLoader.setUrl(SqlUtil.URL);
    dbLoader.setUser(SqlUtil.USER);
    dbLoader.setPassword(SqlUtil.PASSWORD);
    dbLoader.setQuery("select question from question");

    Instances questions = dbLoader.getDataSet();
    boolean classUnset = questions.classIndex() == -1;
    if (classUnset) {
        questions.setClassIndex(questions.numAttributes() - 1);
    }
    System.out.println(questions);

    CSVSaver csvSaver = new CSVSaver();
    csvSaver.setInstances(questions);
    csvSaver.setFile(new File("data/weka/baidubook-csvsaver.csv"));
    csvSaver.writeBatch();
}
过滤(使用Remove过滤器删除第1个属性)
/**
 * Prints houses.arff, removes its first attribute with a Remove filter configured
 * through the {@code -R 1} option, and prints the filtered dataset.
 */
@Test
public void testFilter() throws Exception {
    Instances houses = ConverterUtils.DataSource.read("data/weka/houses.arff");
    houses.setClassIndex(houses.numAttributes() - 1);
    System.out.println(houses);

    Remove removeFirst = new Remove();
    removeFirst.setOptions(new String[]{"-R", "1"});
    removeFirst.setInputFormat(houses);
    Instances filtered = Filter.useFilter(houses, removeFirst);
    System.out.println(filtered);
}
过滤并分类(使用FilteredClassifier在训练和预测时自动过滤)
@Test
public void testFilterOnTheFly() throws Exception {
Instances instances = ConverterUtils.DataSource.read("data/weka/weather.nominal.arff");
instances.setClassIndex(instances.numAttributes() - 1);
System.out.println(instances);
Remove remove = new Remove();
remove.setAttributeIndices("1");
// classify
J48 j48 = new J48();
j48.setUnpruned(true);
FilteredClassifier fc = new FilteredClassifier();
fc.setFilter(remove);
fc.setClassifier(j48);
fc.buildClassifier(instances);
System.out.println(fc);
for(int i =0 ;i
double pred = fc.classifyInstance(instances.instance(i));
System.out.print(instances.classAttribute().value((int)instances.instance(i).classValue()));
System.out.println(instances.classAttribute().value((int) pred));
}
remove.setInputFormat(instances);
Instances newData = Filter.useFilter(instances,remove);
System.out.println(newData);
}