不调包,用numpy实现svm,对鸢尾花进行二分类(smo算法)

import random

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris


class Supper_Vector_Machine:
    """Two-class linear SVM trained with a simplified SMO loop.

    The kernel is the plain dot product between samples, the multipliers
    ``Alpha`` are kept inside the box ``[0, C]``, and labels are expected to
    be ``+1`` / ``-1`` (the KKT test and the box bounds rely on that).

    Attributes set by ``__init__``:
        dataset, target: training samples and their labels (as numpy arrays).
        N, M: number of samples and number of features per sample.
        C: upper bound for every multiplier.
        toler: tolerance used when testing the KKT conditions.
        iter_max: number of consecutive full passes without any multiplier
            update after which training stops.
        Alpha, b, w: multipliers, bias and weight vector (all start at zero;
            ``w`` is only filled in by ``SMO``).
    """

    def __init__(self, dataset, target, C, toler, Iter_max):
        """Store the training data and hyper-parameters, zero the model."""
        # asarray is a no-op for ndarray input and lets nested lists work with
        # the array arithmetic used by the other methods.
        self.dataset = np.asarray(dataset)
        self.target = np.asarray(target)
        self.N, self.M = len(self.dataset), len(self.dataset[0])
        self.C = C
        self.toler = toler
        self.b = 0
        self.Alpha = np.zeros(self.N)
        # Training stops after this many consecutive passes over the data in
        # which no multiplier was changed.
        self.iter_max = Iter_max
        self.w = np.zeros(self.M)

    def Fx(self, i):
        """Return the current decision value for training sample ``i``.

        Computes ``sum_k Alpha[k] * target[k] * <x_i, x_k> + b`` in one
        vectorized expression.
        """
        gram_row = np.matmul(self.dataset, self.dataset[i])
        return np.matmul(self.Alpha * self.target, gram_row) + self.b

    def Kernel(self, i, j):
        """Return the linear kernel (dot product) of samples ``i`` and ``j``."""
        return np.matmul(self.dataset[i], self.dataset[j].T)

    def random_j(self, i):
        """Return a uniformly random sample index different from ``i``.

        Raises:
            ValueError: if there are fewer than two samples, because no index
                other than ``i`` exists (the retry loop would never end).
        """
        if self.N < 2:
            raise ValueError('need at least two samples to pick a second index')
        while True:
            j = random.choice(range(self.N))
            if j != i:
                return j

    def get_L_H(self, i, j):
        """Return the lower and upper clipping bounds ``(L, H)`` for ``Alpha[j]``.

        The bounds keep both multipliers of the pair inside ``[0, C]``; which
        formula applies depends on whether the two labels are equal.
        """
        if self.target[i] != self.target[j]:
            L = max([0, self.Alpha[j] - self.Alpha[i]])
            H = min([self.C, self.C + self.Alpha[j] - self.Alpha[i]])
        else:
            L = max([0, self.Alpha[j] + self.Alpha[i] - self.C])
            H = min([self.C, self.Alpha[i] + self.Alpha[j]])
        return L, H

    def filter(self, L, H, alpha_j):
        """Clip ``alpha_j`` into ``[L, H]`` and return the clipped value."""
        if alpha_j < L:
            alpha_j = L
        if alpha_j > H:
            alpha_j = H
        return alpha_j

    def SMO(self):
        """Train the model in place; updates ``Alpha``, ``b`` and ``w``.

        Every pass visits each sample ``i``; when ``i`` violates the KKT
        conditions (within ``toler``) a random partner ``j`` is drawn and the
        pair is optimized jointly. Training ends once ``iter_max`` consecutive
        passes finish without any update.
        """
        quiet_passes = 0
        while quiet_passes < self.iter_max:
            change_num = 0
            for i in range(self.N):
                Ex_i = self.Fx(i) - self.target[i]
                # Skip samples that already satisfy the KKT conditions.
                margin_error = self.target[i] * Ex_i
                violates_kkt = (
                    (margin_error < -self.toler and self.Alpha[i] < self.C)
                    or (margin_error > self.toler and self.Alpha[i] > 0)
                )
                if not violates_kkt:
                    continue

                j = self.random_j(i)
                print('i:{},j:{}'.format(i, j))
                Ex_j = self.Fx(j) - self.target[j]

                alpha_i = self.Alpha[i]
                alpha_j = self.Alpha[j]

                L, H = self.get_L_H(i, j)
                if L == H:
                    print('L == H')
                    continue

                k_ii = self.Kernel(i, i)
                k_jj = self.Kernel(j, j)
                k_ij = self.Kernel(i, j)
                eta = k_ii + k_jj - 2 * k_ij
                if eta <= 0:
                    print('eta <= 0')
                    continue

                # Compute the clipped candidate first and commit it only when
                # the step is large enough. Writing it before the check would
                # change Alpha[j] without the compensating change to Alpha[i].
                new_alpha_j = self.filter(
                    L, H, alpha_j + self.target[j] * (Ex_i - Ex_j) / eta)
                if abs(new_alpha_j - alpha_j) < 0.00001:
                    print('alpha够精确了')
                    continue

                # Move Alpha[i] by the same amount in the opposite direction
                # (label-weighted) so sum(Alpha * target) stays unchanged.
                new_alpha_i = alpha_i + self.target[i] * self.target[j] * (alpha_j - new_alpha_j)
                self.Alpha[j] = new_alpha_j
                self.Alpha[i] = new_alpha_i

                # Update the bias.
                delta_i = self.target[i] * (new_alpha_i - alpha_i)
                delta_j = self.target[j] * (new_alpha_j - alpha_j)
                b1 = self.b - Ex_i - k_ii * delta_i - k_ij * delta_j
                b2 = self.b - Ex_j - k_ij * delta_i - k_jj * delta_j

                if 0 < self.Alpha[i] < self.C:
                    self.b = b1
                elif 0 < self.Alpha[j] < self.C:
                    self.b = b2
                else:
                    self.b = (b1 + b2) / 2.0

                print(self.Alpha[i], self.Alpha[j])
                change_num += 1
            if change_num == 0:
                quiet_passes += 1
            else:
                quiet_passes = 0
        # Rebuild w from scratch so that calling SMO() again does not add the
        # new weights on top of the previous ones.
        self.w = np.matmul(self.Alpha * self.target, self.dataset)

    def display(self):
        """Print the support vectors and plot the data with the decision line.

        Uses the first two feature columns only. Samples labelled ``1`` are
        drawn blue, all others red, support vectors (``Alpha > 0``) black.
        The line is evaluated on the fixed x-range ``[4, 6]``.
        """
        svm_point = []
        for i in range(self.N):
            if self.Alpha[i] > 0:
                print('第{}个是支持向量'.format(i), self.dataset[i], self.target[i])
                svm_point.append(i)
        x_point = self.dataset[:, 0]
        y_point = self.dataset[:, 1]
        positive = self.target == 1
        plt.scatter(x_point[~positive], y_point[~positive], color='red')
        plt.scatter(x_point[positive], y_point[positive], color='blue')
        # Indexing an empty selection with [:, 0] would fail, so only draw the
        # support vectors when there are any.
        if svm_point:
            plt.scatter(x_point[svm_point], y_point[svm_point], color='black')
        if self.w[1] != 0:
            x = np.linspace(4, 6, 5)
            y = -(self.w[0] * x + self.b) / self.w[1]
            plt.plot(x, y)
        elif self.w[0] != 0:
            # w[1] == 0 means the boundary is vertical: x = -b / w[0].
            x0 = -self.b / self.w[0]
            plt.plot([x0, x0], [y_point.min(), y_point.max()])
        plt.show()


if __name__ == '__main__':
    # Load the iris data once (the bundle holds both 'data' and 'target'),
    # then keep rows 0-99 and the first two feature columns.
    iris = load_iris()
    dataset, target = iris['data'][:100, :2], iris['target'][:100]
    # Relabel for the SVM: class 1 stays +1, every other class becomes -1.
    target = np.where(target == 1, 1, -1)
    # Arguments: C=100, KKT tolerance 0.01, stop after 40 passes without change.
    model = Supper_Vector_Machine(dataset, target, 100, 0.01, 40)
    model.SMO()
    print(model.w)
    print(model.b)
    model.display()

快速理解SMO算法_哔哩哔哩_bilibili

B站这位老师讲的SMO非常详细,推荐大家看一下。不过本文中 i 和 j 的选择并没有采用该视频里的启发式方法(用KKT条件做启发式选择的写法我尝试了几天,都不太成功)。这里的做法是:i 依次遍历全部样本,只对违反KKT条件的样本进行更新;j 则在其余样本中随机选取。这样虽然耗时比较长,但总体效果还可以。

不调包,用numpy实现svm,对鸢尾花进行二分类(smo算法)_第1张图片

这是效果图,看上去分类效果还不错。

你可能感兴趣的:(支持向量机,分类,算法)