基于weka手工实现支持向量机smo算法

关于svm机器学习模型,我主要学习的是周志华老师的西瓜书(《机器学习》);

但是西瓜书中对于参数优化(即:Sequential Minimal Optimization,smo算法)部分讲解的十分简略,看起来不太好懂。因此这一部分参考的是John C. Platt 1998年发表的论文:Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines

值得注意的是,S.S. Keerthi在2001年又发表了一篇名为Imrovements to Platt’s SMO Algorithm for SVM Classifier Design的文章,在这篇文章中,它改进了原版本smo的收敛条件,并融入了许多缓存机制,好处是求解速度更快了,但理解起来较为晦涩。

因为smo数学原理较强,处于学习考虑,我这里的实现参考的的是John C. Platt 1998年发表的论文。

一、支持向量机(SVM)模型

支持向量机就是想找到一个间隔最大的超平面,将正负两种样本分割开来,进而实现分类的一个模型。

基于weka手工实现支持向量机smo算法_第1张图片

支持向量机寻找到的超平面可以用如下公式来表示:

w T x + b = 0 w^Tx+b=0 wTx+b=0

如果输入x,结果大于0,就为正例;
如果输入x,结果小于0,就为负例,进而实现分类任务。

根据周志华西瓜书(《机器学习》p121-123)中的公式推导,我们要想寻找到参数的最优解,最终的优化目标如下:

min ⁡ w , b 1 2 ∣ ∣ w ∣ ∣ 2 s . t . y i ( w T x i + b ) ≥ 1 , i = 1 , 2 , . . . , m \min_{w,b}\frac{1}{2}||w||^2\\ s.t. \quad y_i(w^Tx_i+b) \ge 1,\quad i=1,2,...,m w,bmin21∣∣w2s.t.yi(wTxi+b)1,i=1,2,...,m

上面这个最优化目标对应的 w w w b b b 就是我们最终想要的结果。1

二、序列最小优化算法(SMO)

而优化上述模型所采用的优化算法一种就是二次规划,采用线程的优化包进行求解,但是当样本量非常大的情况下,约束目标的数量也会非常大,会出现维度爆炸的问题,而相比之下,SMO算法就可以很好地解决这个问题。

在学习SMO算法的时候,我首先阅读的是西瓜书上的相关内容,但是十分晦涩,读完后一头雾水。

然后我又找同学借了本李航老师的《统计机器学习》进行阅读,读了几遍之后感觉虽然有了一个大体的思路,但是具体如何编码实现呢?比如如何判定一个样例是否满足KKT条件?还是不太会。

直到最后,被逼无奈之下去看了John C. Platt 1998年发表的原版论文(Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines),看完后真的有种柳暗花明又一村的感觉,感觉比上述两位老师写的教材还要好懂一些,所以十分建议阅读。最重要的是,在文章的末尾,John C. Platt前辈还提供了一段C语言的伪代码,对照着伪代码以及文章中的公式,再回过头来写svm的模型就很容易了。2

三、基于weka实现的SMO

因为svm需要经过一个sigmoid函数类似的指数类型的变换以及核函数的处理(高斯核、拉普拉斯核都会涉及到指数),因此,其值大小是非常重要的。如果svm输出的值为几百或者几千,那么经过指数后,直接就会变为无穷大inf或者Nan。

我在编码的完成后,自己写的svm的性能总是不好,找了很久的原因,最终定位在了这个bug上。

package weka.classifiers.myf;

import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Standardize;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;

enum KernelType {
    KERNEL_LINEAR, KERNEL_POLYNOMIAL, KERNEL_RBF, KERNEL_SIGMOID
}

/**
 * @author YFMan
 * @Description 自定义的 SMO 分类器
 * @Date 2023/6/12 15:45
 */
public class mySMO extends Classifier {

    // 二元支持向量机
    public static class BinarySMO implements Serializable {

        // alpha
        protected double[] m_alpha;

        // bias
        protected double m_b;

        // 训练集
        protected Instances m_train;

        // 权重向量
        protected double[] m_weights;

        // 训练数据的类别标签
        protected double[] m_class;

        // 支持向量集合 {i: 0 < m_alpha[i] < C}
        protected Set<Integer> m_supportVectors;

        // 惩罚因子C,超参数
        protected double m_C = 10.0;

        // 容忍参数
        protected double m_toleranceParameter = 1.0e-3;

        // 四舍五入的容忍参数
        protected double m_epsilon = 1.0e-12;

        // 最大迭代次数
        protected int m_maxIterations = 100000;

        // 当前已经执行的迭代次数
        protected int m_numIterations = 0;


        // 定义核函数枚举类型
        protected KernelType m_kernelType = KernelType.KERNEL_LINEAR;

        // 定义多项式核函数的参数
        protected double m_exponent = 2.0;

        // 定义 高斯核 和 拉普拉斯 核函数的参数
        protected double m_gamma = 1.0;

        // 定义 SIGMOID 核函数的参数 beta
        protected double m_sigmoidBeta = 1.0;


        // 定义 sigmoid 核函数的参数 theta
        protected double m_sigmoidTheta = -1.0;

        // 程序精度误差
        protected double m_Del = 1000 * Double.MIN_VALUE;

        /*
         * @Author YFMan
         * @Description // 构建分类器
         * @Date 2023/6/12 22:03
         * @Param [instances 训练数据集, cl1 正类, cl2 负类]
         * @return void
         **/
        protected void buildClassifier(Instances instances, int cl1, int cl2) throws Exception {
            // 初始化 alpha
            m_alpha = new double[instances.numInstances()];

            // 初始化 bias
            m_b = 0;

            // 初始化训练集
            m_train = instances;

            // 初始化权重向量
            m_weights = new double[instances.numAttributes() - 1];

            // 初始化支持向量集合
            m_supportVectors = new HashSet<Integer>();

            // 初始化 m_class
            m_class = new double[instances.numInstances()];

            // 将标签转换为 -1 和 1
            for (int i = 0; i < m_class.length; i++) {
                // 如果实例的类别标签为负类,则将其转换为 -1
                if (instances.instance(i).classValue() == cl1) {
                    m_class[i] = -1;
                } else {
                    m_class[i] = 1;
                }
            }

            int numChanged = 0;         // 记录改变的拉格朗日乘子的个数
            boolean examineAll = true;  // 是否检查所有的实例

            while ((numChanged > 0 || examineAll) && (m_numIterations < m_maxIterations)) {
                numChanged = 0;
                if (examineAll) {
                    // loop over all training examples
                    for (int i = 0; i < m_train.numInstances(); i++) {
                        numChanged += examineExample(i);
                    }
                } else {
                    // loop over examples where alpha is not 0 & not C
                    for (int i = 0; i < m_train.numInstances(); i++) {
                        if ((m_alpha[i] != 0) && (m_alpha[i] != m_C)) {
                            numChanged += examineExample(i);
                        }
                    }
                }

                if (examineAll) {
                    examineAll = false;
                } else if (numChanged == 0) {
                    examineAll = true;
                }

                m_numIterations++;
            }
        }

        /*
         * @Author YFMan
         * @Description // 计算 SVM 的输出
         * @Date 2023/6/14 19:26
         * @Param [index, inst]
         * @return double
         **/
        public double SVMOutput(Instance instance) throws Exception {
            double result = 0;

            if (m_kernelType == KernelType.KERNEL_LINEAR) {
                for (int i = 0; i < m_weights.length; i++) {
                    result += m_weights[i] * instance.value(i);
                }
            } else {
                // 非线性核函数 计算 SVM 的输出
                for (int i = 0; i < m_train.numInstances(); i++) {
                    // 只有支持向量的拉格朗日乘子才会大于 0 且两个向量不重合
                    if (m_alpha[i] > 0) {
                        result += m_alpha[i] * m_class[i] * kernelFunction(m_train.instance(i), instance);
                    }
                }
            }

            result -= m_b;
            return result;
        }

        /*
         * @Author YFMan
         * @Description // 根据 i2 选择第二个变量,并且更新拉格朗日乘子
         * @Date 2023/6/14 19:58
         * @Param [i2]
         * @return int
         **/
        protected int examineExample(int i2) throws Exception {
            double y2 = m_class[i2];
            double alph2 = m_alpha[i2];
            double E2 = SVMOutput(m_train.instance(i2)) - y2;
            double r2 = E2 * y2;

            if (r2 < -m_toleranceParameter && alph2 < m_C || r2 > m_toleranceParameter && alph2 > 0) {
                // 第一种情况:违反KKT条件
                // 选择第二个变量
                if (m_supportVectors.size() > 1) {
                    // 选择第二个变量
                    int i1 = -1;
                    double max = 0;
                    for (Integer index : m_supportVectors) {
                        double E1 = SVMOutput(m_train.instance(index)) - m_class[index];
                        double temp = Math.abs(E1 - E2);
                        if (temp > max) {
                            max = temp;
                            i1 = index;
                        }
                    }
                    // 如果找到了第二个变量
                    if (i1 >= 0) {
                        if (takeStep(i1, i2) == 1) {
                            return 1;
                        }
                    }
                }
                // 第二种情况:没有选择第二个变量
                for (int index : m_supportVectors) {
                    if (takeStep(index, i2) == 1) {
                        return 1;
                    }
                }
                // 第三种情况:没有选择支持向量
                for (int index = 0; index < m_train.numInstances(); index++) {
                    if (takeStep(index, i2) == 1) {
                        return 1;
                    }
                }
            }
            return 0;
        }

        /*
         * @Author YFMan
         * @Description // 根据 i1 和 i2 更新拉格朗日乘子
         * @Date 2023/6/14 19:59
         * @Param [i1, i2]
         * @return int
         **/
        protected int takeStep(int i1, int i2) throws Exception {
            if (i1 == i2) {
                return 0;
            }

            double alph1 = m_alpha[i1];
            double alph2 = m_alpha[i2];
            double y1 = m_class[i1];
            double y2 = m_class[i2];
            double E1 = SVMOutput(m_train.instance(i1)) - y1;
            double E2 = SVMOutput(m_train.instance(i2)) - y2;
            double s = y1 * y2;

            double L = 0;
            double H = 0;
            if (y1 != y2) {
                L = Math.max(0, alph2 - alph1);
                H = Math.min(m_C, m_C + alph2 - alph1);
            } else {
                L = Math.max(0, alph2 + alph1 - m_C);
                H = Math.min(m_C, alph2 + alph1);
            }

            if (L == H) {
                return 0;
            }

            double k11 = kernelFunction(m_train.instance(i1), m_train.instance(i1));
            double k12 = kernelFunction(m_train.instance(i1), m_train.instance(i2));
            double k22 = kernelFunction(m_train.instance(i2), m_train.instance(i2));
            double eta = k11 + k22 - 2 * k12;

            double a1 = 0;
            double a2 = 0;

            if (eta > 0) {
                a2 = alph2 + y2 * (E1 - E2) / eta;
                if (a2 < L) {
                    a2 = L;
                } else if (a2 > H) {
                    a2 = H;
                }
            } else {
                double f1 = y1 * (E1 + m_b) - alph1 * k11 - s * alph2 * k12;
                double f2 = y2 * (E2 + m_b) - s * alph1 * k12 - alph2 * k22;
                double L1 = alph1 + s * (alph2 - L);
                double H1 = alph1 + s * (alph2 - H);

                // objective function at a2=L
                double Lobj = L1 * f1 + L * f2 + 0.5 * L1 * L1 * k11 + 0.5 * L * L * k22 + s * L * L1 * k12;
                // objective function at a2=H
                double Hobj = H1 * f1 + H * f2 + 0.5 * H1 * H1 * k11 + 0.5 * H * H * k22 + s * H * H1 * k12;

                if (Lobj > Hobj + m_epsilon) {
                    a2 = L;
                } else if (Lobj < Hobj - m_epsilon) {
                    a2 = H;
                } else {
                    a2 = alph2;
                }
            }

            if (Math.abs(a2 - alph2) < m_epsilon * (a2 + alph2 + m_epsilon)) {
                return 0;
            }

            if (a2 > m_C - m_Del * m_C) // m_Del = 1000 *
                // Double.MIN_VALUE,在精度误差上做了一点处理
                a2 = m_C;
            else if (a2 <= m_Del * m_C)
                a2 = 0;

            a1 = alph1 + s * (alph2 - a2);

            // Update threshold to reflect change in Lagrange multipliers
            double b1 = E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12 + m_b;
            double b2 = E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22 + m_b;

            if ((0 < a1 && a1 < m_C) && (0 < a2 && a2 < m_C)) {
                m_b = (b1 + b2) / 2;
            } else if (0 < a1 && a1 < m_C) {
                m_b = b1;
            } else if (0 < a2 && a2 < m_C) {
                m_b = b2;
            }

            // Update weight vector to reflect change in a1 & a2, if linear SVM
            if (m_kernelType == KernelType.KERNEL_LINEAR) {
                int column = 0;
                for (int i = 0; i < m_train.numAttributes(); i++) {
                    if (i != m_train.classIndex()) {
                        m_weights[column] += y1 * (a1 - alph1) * m_train.instance(i1).value(i) + y2 * (a2 - alph2) * m_train.instance(i2).value(i);
                        column++;
                    }
                }
            }

            m_alpha[i1] = a1;
            m_alpha[i2] = a2;

            return 1;
        }

        /*
         * @Author YFMan
         * @Description // 核函数
         * @Date 2023/6/14 19:29
         * @Param [i1, i2]
         * @return double
         **/
        protected double kernelFunction(Instance instance1, Instance instance2) throws Exception {
            switch (m_kernelType) {
                case KERNEL_LINEAR:
                    return linearKernel(instance1, instance2);
                case KERNEL_POLYNOMIAL:
                    return polynomialKernel(instance1, instance2);
                case KERNEL_RBF:
                    return rbfKernel(instance1, instance2);
                case KERNEL_SIGMOID:
                    return sigmoidKernel(instance1, instance2);
                default:
                    throw new Exception("Invalid kernel type.");
            }
        }

        /*
         * @Author YFMan
         * @Description // 线性核函数
         * @Date 2023/6/14 20:33
         * @Param [instance1, instance2]
         * @return double
         **/
        protected double linearKernel(Instance instance1, Instance instance2) {
            double result = 0;
            for (int i = 0; i < m_train.numAttributes() - 1; i++) {
                result += instance1.value(i) * instance2.value(i);
            }
            return result;
        }

        protected double polynomialKernel(Instance instance1, Instance instance2) {
            double result = 0;
            for (int i = 0; i < m_train.numAttributes() - 1; i++) {
                result += instance1.value(i) * instance2.value(i);
            }
            return Math.pow(result + m_gamma, m_exponent);
        }

        /*
         * @Author YFMan
         * @Description // 高斯核函数
         * @Date 2023/6/15 10:46
         * @Param [instance1, instance2]
         * @return double
         **/
        protected double rbfKernel(Instance instance1, Instance instance2) {
            double result = 0;
            for (int i = 0; i < m_train.numAttributes() - 1; i++) {
                result += Math.pow(instance1.value(i) - instance2.value(i), 2);
            }
            return Math.exp(-result / (2 * m_gamma * m_gamma));
        }

        /*
         * @Author YFMan
         * @Description // sigmoid 核函数
         * @Date 2023/6/15 10:47
         * @Param [instance1, instance2]
         * @return double
         **/
        protected double sigmoidKernel(Instance instance1, Instance instance2) {
            double result = 0;
            for (int i = 0; i < m_train.numAttributes() - 1; i++) {
                result += instance1.value(i) * instance2.value(i);
            }
            return Math.tanh(m_sigmoidBeta * result + m_sigmoidTheta);
        }

    }

    // 归一化数据的过滤器
    public static final int FILTER_NORMALIZE = 0;

    // 标准化数据的过滤器
    public static final int FILTER_STANDARDIZE = 1;

    // 不使用过滤器
    public static final int FILTER_NONE = 2;

    // 二元分类器
    protected BinarySMO m_classifier = null;

    // 是否使用过滤器
    protected int m_filterType = FILTER_NORMALIZE;

    // 用于标准化/归一化数据的过滤器
    protected Filter m_Filter = null;

    // 用于标准化数据的过滤器
    protected Filter m_StandardizeFilter = null;

    // 用于二值化数据的过滤器
    protected Filter m_NominalToBinary = null;

    /*
     * @Author YFMan
     * @Description // 构建分类器
     * @Date 2023/6/14 20:29
     * @Param [insts]
     * @return void
     **/
    public void buildClassifier(Instances insts) throws Exception {
        // 标准化数据
        m_StandardizeFilter = new Standardize();
        m_StandardizeFilter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_StandardizeFilter);

        // 二值化数据
        m_NominalToBinary = new NominalToBinary();
        m_NominalToBinary.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_NominalToBinary);



        m_classifier = new BinarySMO();

        m_classifier.buildClassifier(insts, 0, 1);
    }

    /*
     * @Author YFMan
     * @Description // 分类实例
     * @Date 2023/6/14 20:43
     * @Param [inst]
     * @return double[]
     **/
    public double[] distributionForInstance(Instance inst) throws Exception {
        // 过滤实例
        m_StandardizeFilter.input(inst);
        inst = m_StandardizeFilter.output();

        m_NominalToBinary.input(inst);
        inst = m_NominalToBinary.output();

        double[] result = new double[2];

        double output = m_classifier.SVMOutput(inst);
        result[1] = 1.0 / (1.0 + Math.exp(-output));
        result[0] = 1.0 - result[1];

        return result;
    }

    /*
     * @Author YFMan
     * @Description // 主函数
     * @Date 2023/6/14 20:42
     * @Param [argv]
     * @return void
     **/
    public static void main(String[] argv) {
        runClassifier(new mySMO(), argv);
    }
}

四、感悟

支持向量机的优化部分smo数学原理很强,论文中的推导非常清晰,因此文中并没有对其过多解读,因为我解读的再细致,以我对smo的理解,也不可能有原作者好。

同时,虽然自己能将smo侥幸实现,但只能说按照文中公式及伪代码来理解一二,并不敢说对其理解有多深刻。直到现在,也依然有很多不明白的点。

对于计算机科学这门应用学科而言,数学永远是天花板,也许我们能侥幸的把它用起来,但如果真正的想要有所建树和理论创新,可能还要回归到数学吧。


  1. 《机器学习》周志华 ↩︎

  2. Platt J. Sequential minimal optimization: A fast algorithm for training support vector machines[J]. 1998. ↩︎

你可能感兴趣的:(机器学习,支持向量机,数据挖掘,算法,机器学习,人工智能)