模式识别——判别函数及几何分类法

前言

本篇博客对模式识别中的“判别函数与几何分类法”进行了讨论,并重点实现了最小平方误差法(Least Mean Square Error, LMSE)算法,也对不同的多类判别法、分段线性判别法进行了一定的实现,并通过随机生成的数据集加以测试。

项目源码及所使用的数据集参见:PR-EXPT2


基础知识

我们首先了解判别函数与几何分类法的基本知识,在此目中不追求科学语言的精确性,旨在用形式化的表达阐述其大致的思想。

判别函数与几何分类

模式识别——判别函数及几何分类法_第1张图片

上图呈现了典型的几何分类法的特点,其中左侧为线性判别法,右侧为非线性判别法,蓝色的直线与绿色的圆圈就称为判别函数。

判别函数

若将二维模式推广至n维,则线性判别函数的一般形式为:

其中,X为增广模式向量,W为增广权向量。

我们以两类情况为例,对于模式在w1w2中的两类情况,线性判别法的判别方式为:

d(X) = 0时,为不可判别情况。

至此,我们已经知道了用判别函数来进行几何分类法的整个过程,即:

  • 得到判别函数d(X):即求出增广权向量W
  • 用判别函数进行识别:对于一个待测模式X',我们能够用判别函数进行相应判别.

上述过程的第二步:“用判别函数进行识别”是非常简单的,只需要做两个矩阵的乘法运算。关键是第一步,我们该怎么得到判别函数呢,即该怎么求这个W呢?

可以很直观的想到,判别函数一定是通过训练得到的,即用已有的、已知其所属类别的模式,学习得到的。

不妨假设对已知的w_k类来说,这个类里面有X1, X2, ..., Xn这n个模式,那么我们所要求解的W,首先一定是要满足W * Xi > 0 (i = 1, 2, ..., n)这个不等式组的。除此之外,还需要满足什么额外条件呢?

而这些额外条件,恰恰是与我们的类别数有关的。

二分类

如果整个模式集只有两类,w1w2,假设这两个类的判别函数分别为d1(X)d2(X),那么如果一个模式X0属于w1类,就需要满足:

  • d1(X0) > 0,即这个模式在第一类里
  • d2(X0) < 0,即这个模式不在第二类里

那么再回到我们刚才讨论的“额外条件”问题:

假设w1类里有模式X1_1, X1_2, ..., X1_nw2类有模式X2_1, X2_2, ..., X2_n,那么对于判别函数d1的训练,我们首先需要:

d1(X1_i) > 0 (i = 1, 2, ..., n)

其次,还需要满足额外条件:d1(X2_i) < 0,即:

d1(- X2_i) > 0 (i = 1, 2, ..., n)

多分类

那么现在如果模式集有多类,额外条件又该如何确定呢?

在这里我们介绍两种方法:

  1. i_non_i 两分法,或者“正负类”法,即对于类wi来说,我们把所有的模式集分为两种:一种是属于wi的,另一种是不属于wi的。那么求解判别函数di(X)的不等式组为:

    di(Xi_j) > 0 (j = 1, 2, ..., n)

    di(- Xk_j) > 0 (j = 1, 2, ..., n; k为其他的模式类)

  2. i_j 两分法,即对于类wi与类wj来说,这就成了一个二分类的问题。因此,若假设总共有k个模式类,那么如果要把类wi与其他k-1个模式类区分开,就需要求解k-1个判别函数di_1, di_2, ..., di_k-1

    此时,到了”用判别函数进行识别”的时候,若模式X0属于wi类,就需要满足:di_j(X0) > 0,即对于以i打头的所有判别函数,其判别值均需为正。

线性分类算法

到了此处,我们已经明确了整个几何分类法的过程,也知道了在训练判别函数的时候,所要满足的不等式组是什么。那么这个不等式组又该如何求解呢?

这就涉及到了不同的线性分类算法。

虽然这些算法在形式上看起来十分复杂,但我们只需要记住一点:这些算法其实就是在解一个不等式方程组。

我们可以发散地想,这些算法甚至是独立于“模式识别”而存在的。假如你在线性代数的课堂上遇到了一道解不等式方程组的问题,就可以通过这些算法来解决它。其之所以应用在模式识别的领域上,只不过因为,模式识别的几何分类法需要解不等式方程组(即:求判别函数)。

下面我们列举了几种经典的线性分类算法。

  • 感知器算法
  • 梯度法
  • 最小平方误差算法

在实际应用中,我们应该明确这三种算法之间的关系:

  • 感知器算法是梯度法的一种特例。
  • 最小平方误差算法利用了梯度法。
  • 感知器算法与梯度算法在迭代运算时,算法有可能始终不收敛,因此在程序运行的过程中,我们无法得知,这究竟是因为迭代次数还不够,还是因为数据集本身的就线性不可分。而最小平方误差法改善了这一困境。

因此我们可以看到最小平方误差法的优势所在,本篇博客也主要实现了最小平方误差法。

分段线性判别法

在此我们简要介绍分段线性判别法的思想。

  • 分段线性判别法首先对已知类进行划分,将其分为父类和子类。
  • 我们需要训练出同一父类的子类之间的判别函数。
  • 对于某个待测模式,我们先用同一父类的各子类判别函数进行计算,并用效果最好的子类作为父类的判别函数。
  • 之后,我们再通过父类的判别函数,将待测模式划分到某一父类。

方案设计

实验要求

1、请实现以下功能之一
1)编写程序,使用感知器算法、梯度法和最小平方误差法中的一种算法,
训练线性判别函数并进行数据分类,分类时采用多分类情况的两种。

2)编写程序,使用感知器算法、梯度法和最小平方误差法中的两种算法,
训练线性判别函数并进行数据分类,分类时采用多分类情况的一种。

样本数据请自拟并不能少于30个数据样本,类别不少于5类。

其中部分样本用于训练判别函数,部分用于验证,并计算出正确率

2、请编写程序实现分段线性判别法进行数据分类。

样本数据请自拟并不能少于30个数据样本,类别不少于3类,每类至少有2个以上子类。

设计步骤

实验的实现共分为以下几个过程:

  1. 数据集的获取。在本实验中,随机生成了模式数据集。训练样本共分为6类,每类有8个模式;测试样本共分为6类,每类2个模式。

  2. 最小平方误差法(Least Mean Square Error, LMSE)算法的实现:

    最小平方误差算法(Least Mean Square Error, LMSE)
    (1)输入参数:规范化增广样本矩阵x
    (2)求X的伪逆矩阵 x# = (xT x)-1 xT
    (3)设置初值c和B(1)
    (4)计算e(k),并进行可分性判别
    (5)计算W(k+1)、B(k+1)
  3. 采用i_non_i 分类法训练判别函数,并评测识别正确率

  4. 采用i_j 分类法训练判别函数,并评测识别正确率

  5. 实现分段线性判别法,训练子类的判别函数,并评测识别正确率

实验环境

操作系统:MacOS X

开发语言:Python 2.7


具体实现

数据集生成

如上图所示,在[0, 50]的范围内随机生成模式集,并具体划分在0, 1, 2, 3, 4, 5 这6个类中。

具体实现如下:

# encoding:utf-8
# 生成数据集

import numpy as np


# 生成数据集
def gen_data_set(sample_size, class_size):
    x_k = []
    y_k = []

    for k in range(6):
        x = []
        y = []

        if k == 0:
            for i in range(sample_size):
                m = np.random.random() * 20
                n = np.random.random() * m
                x.append(m)
                y.append(n)

        if k == 1:
            for i in range(sample_size):
                m = np.random.random() * 30 + 20
                if m < 25:
                    n = np.random.random() * m
                else:
                    n = np.random.random() * ((-1) * m + 50)
                x.append(m)
                y.append(n)

        if k == 2:
            for i in range(sample_size):
                m = np.random.random() * 25 + 25
                n = np.random.random() * (2 * m - 50) + ((-1) * m + 50)
                x.append(m)
                y.append(n)

        if k == 3:
            for i in range(sample_size):
                m = np.random.random() * 30 + 20
                if m < 25:
                    n = np.random.random() * m + ((-1) * m + 50)
                else:
                    n = np.random.random() * (50 - m) + m
                x.append(m)
                y.append(n)

        if k == 4:
            for i in range(sample_size):
                m = np.random.random() * 20
                n = np.random.random() * m + ((-1) * m + 50)
                x.append(m)
                y.append(n)

        if k == 5:
            for i in range(sample_size):
                m = np.random.random() * 20
                n = np.random.random() * ((-2) * m + 50) + m
                x.append(m)
                y.append(n)

        # 记录坐标
        # for i in range(sample_size):
        #     print str(k) + "," + str(x[i]) + "," + str(y[i])

        x_k.append(x)
        y_k.append(y)

    return x_k, y_k


# 生成数据集并导入文件
def gen_data_to_file(train_file, test_file, class_size=6, train_set_size=10, test_set_size=5):
    train_sample_size = 8
    test_sample_size = 2
    train_x_k, train_y_k = gen_data_set(train_sample_size, class_size)
    test_x_k, test_y_k = gen_data_set(test_sample_size, class_size)

    # 数据集导入文件
    with open(train_file, 'w') as f:
        for i in range(class_size):
            for j in range(train_sample_size):
                f.write(str(i) + "," + str(train_x_k[i][j]) + "," + str(train_y_k[i][j]) + "\n")

    with open(test_file, 'w') as f:
        for i in range(class_size):
            for j in range(test_sample_size):
                f.write(str(i) + "," + str(test_x_k[i][j]) + "," + str(test_y_k[i][j]) + "\n")

并将生成的数据集导入文件。

最小平方误差法

# encoding:utf-8

import numpy as np
from numpy.linalg import inv


# 最小平方误差算法(Least Mean Square Error, LMSE)
# (1)输入参数:规范化增广样本矩阵x
# (2)求X的伪逆矩阵 x# = (xT x)-1 xT
# (3)设置初值c和B(1)
# (4)计算e(k),并进行可分性判别
# (5)计算W(k+1)、B(k+1)


def least_mean_square_error(data, b_1=0.1, c=1, k_n=100000):
    # x:规范化增广样本"矩阵"
    x = np.array(data, dtype=np.float64)
    x_sharp = inv(x.T.dot(x)).dot(x.T)
    row, col = x.shape

    # print x

    k = 1
    while k < k_n:
        if k % 10000 == 0:
            print "第" + str(k) + "次迭代..."

        if k == 1:
            # 设置初值B(1)
            b_k = (np.zeros(row) + b_1).T
            # b_k = x[:, 0]
            w_k = x_sharp.dot(b_k)
        else:
            b_k = b(k + 1, b_k, c, e_k)
            w_k = w(k + 1, w_k, c, x_sharp, e_k)

        # w_k = w(k+1, x_sharp, b_k)        # 方法二
        e_k = e(k + 1, x, w_k, b_k)

        if (e_k == 0).all():
            print "第" + str(k) + "次迭代... " + "线性可分,解为"
            print w_k
            return w_k
        elif (e_k >= 0).all():
            k = k + 1
            print "第" + str(k) + "次迭代... " + "线性可分,继续迭代可得最优解"
        elif (e_k <= 0).all():
            # print e_k
            print "第" + str(k) + "次迭代... " + "e_k < 0, 停止迭代,检查XW(k)"
            if (x.dot(w_k) > 0).all():
                print "线性可分,解为" + str(w_k)
                return w_k
            else:
                print "XW(k) > 0 不成立,故无解,算法结束"
                return None
        else:
            k = k + 1

    return None


def b(k, b_k, c, e_k):
    return b_k + c * (e_k + np.abs(e_k))


def e(k, x, w_k, b_k):
    return x.dot(w_k) - b_k


# 方法一
def w(k, w_k, c, x_sharp, e_k):
    return w_k + c * x_sharp.dot(np.abs(e_k))

# 方法二
# def w(k, x_sharp, b_k):
#     return x_sharp.dot(b_k)

i_j 分类法

# encoding:utf-8
# w_i / w_j 两分法

import numpy as np
from algorithm.LMSE import least_mean_square_error


# 训练得到能够分开任意两类i/j的判别函数d_k_k
def train_to_file(data_k, to_file):
    d_k_k = []
    i = 0
    j = 0
    # data_i:第i个模式类的增广样本的"序列"
    for data_i in data_k:
        d_i_k = []
        for data_j in data_k:
            d_i_j = []
            if data_i != data_j:
                print "----------------d_" + str(i) + "_" + str(j) + "----------------"
                data_i_j = data_i[:]
                for point in data_j:
                    # 规范化增广样本
                    data_i_j.append([x * (-1) for x in point])
                d_i_j = least_mean_square_error(data_i_j)
                if d_i_j is None:
                    # 如果无法求得判别函数
                    print
                    print "data_i_j不存在判别函数: "
                    for data in data_i_j:
                        print data
                        return None
            d_i_k.append(d_i_j)
            j = j + 1
            print
        d_k_k.append(d_i_k)
        i = i + 1
        j = 0

    # 写入文件
    i = 0
    j = 0
    with open(to_file, 'w') as f:
        for d_i_k in d_k_k:
            for d_i_j in d_i_k:
                f.write(str(i) + "," + str(j))
                for d in d_i_j:
                    f.write("," + str(d))
                f.write("\n")
                j = j + 1
            i = i + 1
            j = 0

    return d_k_k


# 用任意两类i/j的判别函数d_k_k,对模式x进行识别(模式x为增广矩阵)
def recognize(x, d_k_k, index):
    tag = 0
    for d_i_k in d_k_k:
        d_i_mat = np.array(d_i_k)
        if (d_i_mat.dot(x) > 0).all():
            output_str = ""
            if tag == index:
                output_str = "分类正确"
            else:
                output_str = "分类错误"

            print str(x) + "\t属于模式类" + str(tag) + "," + output_str
            return tag
        else:
            tag = tag + 1

    print str(x) + "\t属于IR区"
    return -1


def load_d_k_k(from_file):
    class_tag = []
    d_k_k = []
    d_i_k = []
    with open(from_file, 'r') as f:
        for line in f.readlines():
            d_i_j_str = line.strip().split(',')
            if d_i_j_str[0] not in class_tag:  # 出现了一个新的类
                class_tag.append(d_i_j_str[0])
                if len(class_tag) != 1:
                    d_k_k.append(d_i_k)
                    d_i_k = []
            if len(d_i_j_str) != 2:
                d_i_k.append(np.array([float(x) for x in d_i_j_str[2:]]))
        d_k_k.append(d_i_k)
    return d_k_k


def i_j_main(train_data_k, test_data_k, i_j_file):
    train_to_file(train_data_k, i_j_file)
    return load_d_k_k(i_j_file)

i_non_i 分类法

# encoding:utf-8
# w_i / non_w_i 两分法

from algorithm.LMSE import least_mean_square_error


# 训练得到每个模式类的判别函数d_k
def train_to_file(data_k, to_file):
    d_k = []
    index = 0
    for data_i in data_k:
        print "--------------------d_" + str(index) + "--------------------"
        # data_i:第i个模式类的增广样本的"序列"
        data = data_i[:]
        for data_other in data_k:
            if data_i != data_other:
                for point in data_other:
                    # 规范化增广样本
                    data.append([x * (-1) for x in point])
        d_i = least_mean_square_error(data,c=0.28)
        if d_i is not None:
            d_k.append(d_i)
        else:
            # 如果无法求得判别函数
            print
            print "data_" + str(index) + "不存在判别函数: " + "....    " + str(len(data))
            for data_element in data:
                print data_element
        index = index + 1
        print
    return d_k


# 用k个模式类的判别函数d_k,对模式x进行识别
def recognize(x, d_k):
    positive_tag = 0
    negative_tag = 0
    index = 0
    tag = 0

    for d_i in d_k:
        if d_i.dot(x) > 0:
            tag = index
            positive_tag = positive_tag + 1
            if positive_tag > 1:
                print "d_i(X) > 0 的条件超过一个,分类失效"
                return None
        else:
            negative_tag = negative_tag + 1
        index = index + 1

    if positive_tag == 0:
        print "d_i(X) > 0 的条件无法满足,分类失效"
        return None

    print "X属于模式类" + str(tag) + ":"
    print d_k[tag]
    return d_k[tag]


def i_non_i_main(train_data_k, test_data_k, i_j_file):
    train_to_file(train_data_k, i_j_file)

分段线性判别法

在分段线性判别法中,我们假设共有父类0-2,且每个父类均有2个子类。具体实现如下:

# encoding:utf-8
# 分段线性判别函数:已知子类的划分

import numpy as np

from algorithm.LMSE import least_mean_square_error


def train_to_file(data_k, to_file):
    d_k = []
    index = 0
    for i in range(0, 6, 2):
        print "------------划分父类_" + str(index) + "------------"
        data_division_1 = data_k[i][:]
        data_division_2 = data_k[i + 1][:]
        data_i = data_division_1
        for point in data_division_2:
            data_i.append([x * (-1) for x in point])
        d_i = least_mean_square_error(data_i, b_1=1, c=0.8)
        if d_i is None:
            # 如果无法求得判别函数
            print
            print "data_i不存在判别函数: "
            for data in data_i:
                print data
                # return None
        d_k.append(d_i)
        index += 1
        print

    # 写入文件
    index = 0
    with open(to_file, 'w') as f:
        for d_i in d_k:
            f.write(str(index))
            for d in d_i:
                f.write("," + str(d))
            f.write("\n")
            index += 1

    return d_k


def recognize(x, d_k, index):
    d_0_division_1 = d_k[0].dot(x)
    d_0_division_2 = ((-1) * d_k[0]).dot(x)
    d_1_division_1 = d_k[1].dot(x)
    d_1_division_2 = ((-1) * d_k[1]).dot(x)
    d_2_division_1 = d_k[2].dot(x)
    d_2_division_2 = ((-1) * d_k[2]).dot(x)

    # 首先用子类代替父类
    d_0 = d_0_division_1 if d_0_division_1 > d_0_division_2 else d_0_division_2
    d_1 = d_1_division_1 if d_1_division_1 > d_1_division_2 else d_1_division_2
    d_2 = d_2_division_1 if d_2_division_1 > d_2_division_2 else d_2_division_2

    # 再用父类判别
    tag = -1
    if d_0 > d_1 and d_0 > d_2:
        tag = 0
    if d_1 > d_0 and d_1 > d_2:
        tag = 1
    if d_2 > d_0 and d_2 > d_1:
        tag = 2

    if tag == index:
        output_str = "分类正确"
    elif tag == -1:
        output_str = "IR区"
    else:
        output_str = "分类错误"

    print str(x) + "\t属于模式类" + str(tag) + "," + output_str
    return tag


def load_d_k(from_file):
    class_tag = []
    d_k = []
    with open(from_file, 'r') as f:
        for line in f.readlines():
            d_i_str = line.strip().split(',')
            d_k.append(np.array([float(x) for x in d_i_str[1:]]))
    return d_k


def class_division_main(train_data_k, test_data_k, class_division_file):
    train_to_file(train_data_k, class_division_file)
    return load_d_k(class_division_file)

测试结果

数据集

我们采用的训练集为:

0,18.5029168459,16.8228372586
0,15.1155598703,7.63485856842
0,6.95647929713,5.11539442509
0,0.597347656694,0.288674350297
0,6.36715034554,5.60396174026
0,15.7424995772,8.58485907293
0,18.0807236204,11.3012359879
0,10.1722692483,4.65541515255
1,47.8757125439,1.16332889424
1,26.0643690758,5.61636639878
1,31.1129222794,17.8499786928
1,23.5410782063,19.9956964598
1,26.6069909527,10.9217740638
1,39.7166244327,10.2806301464
1,49.5181370187,0.375776889409
1,49.2415167242,0.707007895968
2,37.8611287104,31.5632667325
2,27.6685288498,23.2728667808
2,47.4112562524,29.4781641683
2,37.9726890343,36.0084457108
2,37.0258232578,24.4223677169
2,44.5991664403,20.8661377571
2,49.2948779017,39.4635364912
2,32.0020253605,19.3541047235
3,44.2110390945,44.4664083051
3,38.3651365989,42.9632379304
3,35.960724763,44.7193712254
3,26.4532769283,34.8314226887
3,48.9018927779,49.5881169749
3,26.6484189936,31.2731506208
3,41.9097085676,47.9449438728
3,35.331013959,48.0256107217
4,3.81087608968,47.13431803
4,15.8810932052,35.593631285
4,5.46591652097,48.2716875232
4,18.7156249506,42.1656191735
4,13.4111687002,39.336560121
4,8.82628702024,49.9086481899
4,17.6593190139,48.6801651249
4,17.4355973285,45.2334103922
5,2.32851634741,7.54190740531
5,0.192342631621,17.2004066838
5,8.26038228069,29.4464752936
5,15.4799873459,34.2505881121
5,13.3166961224,32.4399699833
5,16.5491387156,18.6037637969
5,10.6585086421,21.9943090312
5,12.1629069128,34.3220738065

测试集为:

0,12.3228402554,2.9405381601
0,14.1333891941,0.65068883671
1,47.7641004258,0.797181091341
1,44.9040483706,3.78286547533
2,33.6360204882,20.0574966127
2,43.5656985063,25.4610595264
3,35.3512989762,40.962956587
3,42.7841966519,43.4279532924
4,19.2747318386,31.3287647744
4,7.48232294478,47.1481142472
5,16.0034786096,31.2251124134
5,0.346491676919,25.5140238588

数据集为二维模式,第一个标号为模式所属的类别,其余两个为模式的坐标。

i_j 分类法

采用LMSE算法,判别函数的训练结果如下:

----------------d_0_1----------------
第3088次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.03574661 -0.00627283  0.86694337]

----------------d_0_2----------------
第1484次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.01770862 -0.00584334  0.52596256]

----------------d_0_3----------------
第1177次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.01803347 -0.02400582  0.17017424]

----------------d_0_4----------------
第650次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.00532768 -0.0099107   0.16814849]

----------------d_0_5----------------
第5069次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.05118984 -0.05614292  0.09732183]

----------------d_1_0----------------
第3244次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.03574661  0.00627283 -0.86694337]


----------------d_1_2----------------
第10000次迭代...20000次迭代...22352次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.06312515 -0.09565371  3.77142462]

----------------d_1_3----------------
第1286次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.00263534 -0.01700837  0.50213301]

----------------d_1_4----------------
第1823次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.00410129 -0.01483631  0.49321121]

----------------d_1_5----------------
第2199次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.02822048  0.00192831 -0.60289858]

----------------d_2_0----------------
第1552次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.01770862  0.00584334 -0.52596256]

----------------d_2_1----------------
第10000次迭代...20000次迭代...20438次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.06312515  0.09565371 -3.77142462]


----------------d_2_3----------------
第10000次迭代...13260次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.03983482 -0.05302738  0.4967981 ]

----------------d_2_4----------------
第860次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.00924338 -0.0073895   0.01622401]

----------------d_2_5----------------
第1067次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.01752869  0.00109051 -0.41037223]

----------------d_3_0----------------
第2266次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.01803347  0.02400582 -0.17017424]

----------------d_3_1----------------
第1289次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.00263534  0.01700837 -0.50213301]

----------------d_3_2----------------
第3888次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.03983482  0.05302738 -0.4967981 ]


----------------d_3_4----------------
第848次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.01818381 -0.0101775  -0.02652464]

----------------d_3_5----------------
第887次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.01823988  0.00124634 -0.42504089]

----------------d_4_0----------------
第439次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.00532768  0.0099107  -0.16814849]

----------------d_4_1----------------
第1578次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.00410129  0.01483631 -0.49321121]

----------------d_4_2----------------
第549次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.00924338  0.0073895  -0.01622401]

----------------d_4_3----------------
第1420次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.01818381  0.0101775   0.02652464]


----------------d_4_5----------------
第10000次迭代...20000次迭代...30000次迭代...37054次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.06526169  0.12942484 -5.54312717]

----------------d_5_0----------------
第5050次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.05118984  0.05614292 -0.09732183]

----------------d_5_1----------------
第2199次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.02822048 -0.00192831  0.60289858]

----------------d_5_2----------------
第1303次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.01752869 -0.00109051  0.41037223]

----------------d_5_3----------------
第1005次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.01823988 -0.00124634  0.42504089]

----------------d_5_4----------------
第10000次迭代...20000次迭代...30000次迭代...37479次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.06526169 -0.12942484  5.54312717]


[12.3228402554, 2.9405381601, 1]    属于模式类0,分类正确
[14.1333891941, 0.65068883671, 1]   属于模式类0,分类正确
[47.7641004258, 0.797181091341, 1]  属于模式类1,分类正确
[44.9040483706, 3.78286547533, 1]   属于模式类1,分类正确
[33.6360204882, 20.0574966127, 1]   属于模式类2,分类正确
[43.5656985063, 25.4610595264, 1]   属于模式类2,分类正确
[35.3512989762, 40.962956587, 1]    属于模式类3,分类正确
[42.7841966519, 43.4279532924, 1]   属于模式类3,分类正确
[19.2747318386, 31.3287647744, 1]   属于IR区
[7.48232294478, 47.1481142472, 1]   属于模式类4,分类正确
[16.0034786096, 31.2251124134, 1]   属于模式类5,分类正确
[0.346491676919, 25.5140238588, 1]  属于模式类5,分类正确

识别正确率为:0.916666666667

训练数据、测试数据、识别结果如下图所示:

i_non_i 分类法

--------------------d_0--------------------
第1748次迭代... e_k < 0, 停止迭代,检查XW(k)
XW(k) > 0 不成立,故无解,算法结束

data_0不存在判别函数: ....    48
[18.5029168459, 16.8228372586, 1]
[15.1155598703, 7.63485856842, 1]
[6.95647929713, 5.11539442509, 1]
[0.597347656694, 0.288674350297, 1]
[6.36715034554, 5.60396174026, 1]
[15.7424995772, 8.58485907293, 1]
[18.0807236204, 11.3012359879, 1]
[10.1722692483, 4.65541515255, 1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]

--------------------d_1--------------------
第3180次迭代... e_k < 0, 停止迭代,检查XW(k)
XW(k) > 0 不成立,故无解,算法结束

data_1不存在判别函数: ....    48
[47.8757125439, 1.16332889424, 1]
[26.0643690758, 5.61636639878, 1]
[31.1129222794, 17.8499786928, 1]
[23.5410782063, 19.9956964598, 1]
[26.6069909527, 10.9217740638, 1]
[39.7166244327, 10.2806301464, 1]
[49.5181370187, 0.375776889409, 1]
[49.2415167242, 0.707007895968, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]

--------------------d_2--------------------
第148次迭代... e_k < 0, 停止迭代,检查XW(k)
XW(k) > 0 不成立,故无解,算法结束

data_2不存在判别函数: ....    48
[37.8611287104, 31.5632667325, 1]
[27.6685288498, 23.2728667808, 1]
[47.4112562524, 29.4781641683, 1]
[37.9726890343, 36.0084457108, 1]
[37.0258232578, 24.4223677169, 1]
[44.5991664403, 20.8661377571, 1]
[49.2948779017, 39.4635364912, 1]
[32.0020253605, 19.3541047235, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]

--------------------d_3--------------------
第1207次迭代... e_k < 0, 停止迭代,检查XW(k)
XW(k) > 0 不成立,故无解,算法结束

data_3不存在判别函数: ....    48
[44.2110390945, 44.4664083051, 1]
[38.3651365989, 42.9632379304, 1]
[35.960724763, 44.7193712254, 1]
[26.4532769283, 34.8314226887, 1]
[48.9018927779, 49.5881169749, 1]
[26.6484189936, 31.2731506208, 1]
[41.9097085676, 47.9449438728, 1]
[35.331013959, 48.0256107217, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]

--------------------d_4--------------------
第10000次迭代...
第20000次迭代...
第30000次迭代...
第40000次迭代...
第50000次迭代...
第52655次迭代... e_k < 0, 停止迭代,检查XW(k)
XW(k) > 0 不成立,故无解,算法结束

data_4不存在判别函数: ....    48
[3.81087608968, 47.13431803, 1]
[15.8810932052, 35.593631285, 1]
[5.46591652097, 48.2716875232, 1]
[18.7156249506, 42.1656191735, 1]
[13.4111687002, 39.336560121, 1]
[8.82628702024, 49.9086481899, 1]
[17.6593190139, 48.6801651249, 1]
[17.4355973285, 45.2334103922, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]

--------------------d_5--------------------
第534次迭代... e_k < 0, 停止迭代,检查XW(k)
XW(k) > 0 不成立,故无解,算法结束

data_5不存在判别函数: ....    48
[2.32851634741, 7.54190740531, 1]
[0.192342631621, 17.2004066838, 1]
[8.26038228069, 29.4464752936, 1]
[15.4799873459, 34.2505881121, 1]
[13.3166961224, 32.4399699833, 1]
[16.5491387156, 18.6037637969, 1]
[10.6585086421, 21.9943090312, 1]
[12.1629069128, 34.3220738065, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]

故采用i_non_i分类法,无法训练出样本数据的判别函数

采用LMSE算法,通过i_non_i 分类法,多次更改算法因子B(1)c,均未训练出样本数据的判别函数。

分段线性判别法

在已知子类划分的情况下,分段线性判别法的输出结果如下:

------------划分父类_0------------
第3846次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[-0.35746611 -0.0627283   8.66943368]

------------划分父类_1------------
第3948次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[ 0.39834821 -0.53027375  4.96798099]

------------划分父类_2------------
第10000次迭代...20000次迭代...30000次迭代...40000次迭代...46418次迭代... e_k < 0, 停止迭代,检查XW(k)
线性可分,解为[  0.6524255    1.29430561 -55.43026674]

[12.3228402554, 2.9405381601, 1]    属于模式类2,分类错误
[14.1333891941, 0.65068883671, 1]   属于模式类2,分类错误
[47.7641004258, 0.797181091341, 1]  属于模式类1,分类错误
[44.9040483706, 3.78286547533, 1]   属于模式类2,分类错误
[33.6360204882, 20.0574966127, 1]   属于模式类1,分类正确
[43.5656985063, 25.4610595264, 1]   属于模式类1,分类正确
[35.3512989762, 40.962956587, 1]    属于模式类2,分类错误
[42.7841966519, 43.4279532924, 1]   属于模式类2,分类错误
[19.2747318386, 31.3287647744, 1]   属于模式类1,分类错误
[7.48232294478, 47.1481142472, 1]   属于模式类1,分类错误
[16.0034786096, 31.2251124134, 1]   属于模式类1,分类错误
[0.346491676919, 25.5140238588, 1]  属于模式类2,分类正确

识别正确率为:0.25

实验结果如下图所示:

如上的四个子图中,train_data显示了三个父类的训练数据集,class_division是已知的子类划分方式,discriminant_func是训练所得的判别函数(3个父类,每个父类有2个子类),test_data是判别函数对测试集的识别结果。


小结

在本篇博客中,我们学习了判别函数与几何分类法的思想精髓,也对LMSE算法在多分类条件下进行模式识别的整个过程加以实现。核心内容有以下几点:

  • 应明确多分类的三种情况。
  • 感知器算法与梯度算法在迭代运算时,算法有可能始终不收敛,因此在程序运行的过程中,我们无法得知,这究竟是因为迭代次数还不够,还是因为数据集本身的就线性不可分。而最小平方误差法改善了这一困境。
  • 学习分段线性判别法在已知子类划分情况下的判别过程。

在上述实践的过程中,值得我们思考的地方还有:

  • 尽管随机生成了数据集,但为确保模式集线性可分,在此实验中是基于事先给定的准线模拟生成数据的。而在现实的应用场景下,多数模式集不一定线性可分。
  • 实验数据在i_non_i 分类法下,并未训练出相应的判别函数,这也反映了i_non_ii_j 两种分类法之间的特性:虽然i_j两分法则比i_non_i 两分法需要训练出更多的判别函数,但其对模式线性可分的可能性更大一点,因为其不需要考虑其他模式类的影响,受到的限制条件少。这是i_j 两分法的主要优点。
  • 在实验中,分段线性判别法训练出的判别函数,识别正确率只有25%,究其原因,是父类中所含的子类数目少,因此训练出的判别函数不能具有很强的代表性。且对于非线性判别法而言,分段线性判别法本身较为简单,在现实应用中,为追求识别效果,还应考虑如势函数法等其他算法。

参考资料

[1] 《模式识别导论》. 齐敏、李大健、郝重阳编著.

[2]《利用Python进行数据分析》. Wes Mckinney著.

你可能感兴趣的:(模式识别)