Java 机器学习库Smile实战(一)SVM

       本文不会介绍SVM的基本原理,如果想了解SVM基本原理,请参阅相关书籍。

1. 二分类

       Smile 库的SVM类是一个泛型类型,默认情况下进行二分类,选择参数为核函数类型和惩罚项参数。

import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

 double gamma = 1.0;
 double C = 1.0;

 //通过某种方式获取训练数据及其类标
 double[][] data = ...
 int[] label = ...

 SVM<double[]> svm = new SVM<double[]>(new GaussianKernel(gamma), C);
 svm.learn(data, label); //训练模型
 svm.finish();

      接下来就可以对未知数据进行分类:

  //获取测试数据
  double[][] testData = ...
  int[] result = new int[testData.length];
  for(int i=0; i < testData.length; i++){
      result[i] = svm.predict(testData);
  }

2. 多分类

       接下来是我利用SVM对iris数据集进行分类的程序。首先我们将iris数据保存iris.txt文件,如下结构:

5.1 3.5 1.4 0.2 0
4.9 3   1.4 0.2 0
...

       每一行代表一个测试数据项,前4列是属性向量,最后一列是类标(在Smile中类标不能为负数,并且只能是从0开始的正整数,所以上述类标为:0、1、2)。检测的完整的源代码如下:

import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** * Created by zhanghuayan on 2017/1/16. */
public class ClassificationTest {

    public static void main(String[] args) throws Exception {

        List<List<Double>> datas = new ArrayList<List<Double>>();
        List<Double> data = new ArrayList<Double>();
        List<Integer> labels = new ArrayList<Integer>();

        String line;
        List<String> lines;
        File file = new File("iris.txt");
        BufferedReader reader = new BufferedReader(new FileReader(file));
        while ((line = reader.readLine()) != null) {
            lines = Arrays.asList(line.trim().split("\t"));
            for (int i = 0; i < lines.size() - 1; i++) {
                data.add(Double.parseDouble(lines.get(i)));
            }
            labels.add(Integer.parseInt(lines.get(lines.size() - 1)));

            datas.add(data);
            data = new ArrayList<Double>();

        }

        //转换label
        int[] label = new int[labels.size()];
        for (int i = 0; i < label.length; i++) {
            label[i] = labels.get(i);
        }

        //转换属性
        int rows = datas.size();
        int cols = datas.get(0).size();
        double[][] srcData = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                srcData[i][j] = datas.get(i).get(j);
            }
        }

        SVM<double[]> svm = new SVM<double[]>(new GaussianKernel(1.0), 1.0, 3, SVM.Multiclass.ONE_VS_ALL);
        svm.learn(srcData, label);
        svm.finish();

        double right = 0;
        for (int i = 0; i < srcData.length; i++) {
            int tag = svm.predict(srcData[i]);
            if (tag == label[i]) {
                right += 1;
            }
        }
        right = right / srcData.length;

        System.out.println("Accrurate: " + right * 100 + "%");

    }
}

SVM初始化的四个参数分别为

  1. 核函数类型;
  2. 惩罚项参数;
  3. 类标种类数;
  4. 多分类策略。

你可能感兴趣的:(机器学习,SVM,Smile库)