This post discusses discriminant functions and geometric classification in pattern recognition. It focuses on implementing the Least Mean Square Error (LMSE) algorithm, also implements the different multi-class discrimination schemes and the piecewise-linear discriminant method, and tests everything on a randomly generated data set.
The project source code and the data sets used are available at: PR-EXPT2
We start with the basics of discriminant functions and geometric classification. This section does not aim for the precision of rigorous scientific language; the goal is to sketch the main ideas in loosely formal terms.
The figure above illustrates typical geometric classification: the left panel shows linear discrimination, the right panel non-linear discrimination. The blue straight line and the green circle are called discriminant functions.
Generalizing from two-dimensional patterns to $n$ dimensions, the general form of a linear discriminant function is

$$d(X) = W^T X$$

where $X$ is the augmented pattern vector and $W$ is the augmented weight vector.

Take the two-class case as an example. For patterns belonging to one of the two classes $\omega_1$ and $\omega_2$, the linear decision rule is

$$d(X) = W^T X \begin{cases} > 0, & X \in \omega_1 \\ < 0, & X \in \omega_2 \end{cases}$$

When $d(X) = 0$, the pattern cannot be decided.
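As a tiny numerical illustration of this rule (the weight vector here is made up for illustration, not taken from the experiments below), applying the discriminant function is just a dot product:

import numpy as np

# A minimal sketch of applying d(X) = W^T X, with a hypothetical trained
# augmented weight vector w and an augmented pattern vector x = (x1, x2, 1).
w = np.array([-0.5, 1.0, 2.0])
x = np.array([3.0, 4.0, 1.0])
d = w.dot(x)
if d > 0:
    print "X belongs to class 1"
elif d < 0:
    print "X belongs to class 2"
else:
    print "X cannot be decided"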
At this point we know the whole procedure of geometric classification with discriminant functions:

1. Train the discriminant function $d(X)$, i.e. solve for the augmented weight vector $W$;
2. Use the discriminant function for recognition, i.e. evaluate $d(X')$ for a pattern $X'$ to be recognized and decide accordingly.

The second step, recognition with the discriminant function, is very simple: it is just a matrix multiplication, as sketched above. The key is the first step: how do we obtain the discriminant function, i.e. how do we solve for $W$?
It is intuitively clear that the discriminant function must be obtained by training, that is, learned from existing patterns whose class labels are already known.
Suppose the known class $\omega_k$ contains the $n$ patterns $X_1, X_2, \dots, X_n$. Then the $W$ we are solving for must, first of all, satisfy the system of inequalities $W^T X_i > 0\ (i = 1, 2, \dots, n)$. What additional conditions does it have to satisfy beyond that?
These additional conditions are determined precisely by the number of classes.
If the whole pattern set has only two classes, $\omega_1$ and $\omega_2$, with discriminant functions $d_1(X)$ and $d_2(X)$ respectively, then a pattern $X_0$ belonging to class $\omega_1$ must satisfy:

$d_1(X_0) > 0$, i.e. the pattern is in the first class;
$d_2(X_0) < 0$, i.e. the pattern is not in the second class.

Now back to the question of the "additional conditions" discussed above:
Suppose class $\omega_1$ contains the patterns $X_{1,1}, X_{1,2}, \dots, X_{1,n}$ and class $\omega_2$ contains the patterns $X_{2,1}, X_{2,2}, \dots, X_{2,n}$. To train the discriminant function $d_1$ we first require

$$d_1(X_{1,i}) > 0 \quad (i = 1, 2, \dots, n)$$

and, as the additional condition, $d_1(X_{2,i}) < 0$, which is equivalent to

$$d_1(-X_{2,i}) > 0 \quad (i = 1, 2, \dots, n)$$
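In code, this amounts to negating the augmented samples of the opposing class before handing everything to the solver, so that every row of the resulting "normalized" sample set only needs to satisfy $d_1(X) > 0$. A minimal sketch (the sample values are made up):

# Build the normalized augmented sample set for training d_1.
# class_1 / class_2 are hypothetical lists of augmented samples [x1, x2, 1].
class_1 = [[1.0, 2.0, 1], [2.0, 1.0, 1]]
class_2 = [[5.0, 6.0, 1], [6.0, 5.0, 1]]

normalized = class_1[:] + [[-v for v in p] for p in class_2]
# every row X of `normalized` must now satisfy W^T X > 0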
Now, if the pattern set has more than two classes, how should the additional conditions be determined?
Here we introduce two approaches:
The i_non_i dichotomy, or "positive/negative class" method: for class $\omega_i$ we split the whole pattern set into two groups, the patterns belonging to $\omega_i$ and the patterns not belonging to $\omega_i$. The system of inequalities for the discriminant function $d_i(X)$ is then

$$d_i(X_{i,j}) > 0 \quad (j = 1, 2, \dots, n)$$
$$d_i(-X_{k,j}) > 0 \quad (j = 1, 2, \dots, n;\ k \text{ running over all other classes})$$
The i_j dichotomy: for each pair of classes $\omega_i$ and $\omega_j$, this becomes an ordinary two-class problem. So, assuming there are $k$ pattern classes in total, separating class $\omega_i$ from the other $k-1$ classes requires solving $k-1$ discriminant functions $d_{i,1}, d_{i,2}, \dots, d_{i,k-1}$.
At recognition time, a pattern $X_0$ belongs to class $\omega_i$ only if $d_{i,j}(X_0) > 0$ for every $j$, i.e. every discriminant function whose first index is $i$ must give a positive value.
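A minimal sketch of this recognition rule (the weight vectors here are toy values, not the trained results reported later):

import numpy as np

# d_k_k[i] stacks the weight vectors d_{i,j} of class i (hypothetical values).
d_k_k = [np.array([[1.0, 0.0, -2.0],    # d_{0,1}
                   [0.0, 1.0, -2.0]]),  # d_{0,2}
         np.array([[-1.0, 0.0, 2.0],    # d_{1,0}
                   [0.0, 1.0, -6.0]])]  # d_{1,2}
x = np.array([3.0, 4.0, 1.0])           # augmented pattern vector

for i, d_i in enumerate(d_k_k):
    if (d_i.dot(x) > 0).all():          # all d_{i,j}(X) > 0  =>  X belongs to class i
        print "X belongs to class " + str(i)
        break
else:
    print "X falls in the indeterminate (IR) region"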
By now we have a clear picture of the whole geometric classification procedure, and we know what system of inequalities has to be satisfied when training a discriminant function. How, then, is this system of inequalities solved?
This is where the various linear classification algorithms come in.
Although these algorithms look complicated on paper, we only need to remember one thing: they are essentially solving a system of inequalities.
Taking the thought a bit further, these algorithms exist independently of "pattern recognition": if you met a system of inequalities in a linear algebra class, you could solve it with the same algorithms. They are used in pattern recognition simply because geometric classification requires solving systems of inequalities (i.e. finding discriminant functions).
Below we list several classic linear classification algorithms.
In practice we should be clear about the relationship between these three algorithms:
From this we can see the advantage of the Least Mean Square Error method, which is also the algorithm this post mainly implements.
Here we briefly introduce the idea of the piecewise-linear discriminant method.
1. Implement one of the following:
1) Write a program that uses one of the perceptron algorithm, the gradient method, and the Least Mean Square Error method to train linear discriminant functions and classify data, using both multi-class schemes.
2) Write a program that uses two of the perceptron algorithm, the gradient method, and the Least Mean Square Error method to train linear discriminant functions and classify data, using one of the multi-class schemes.
Prepare your own sample data: no fewer than 30 samples and at least 5 classes. Use part of the samples to train the discriminant functions and the rest for validation, and compute the accuracy.
2. Write a program that implements the piecewise-linear discriminant method for data classification.
Prepare your own sample data: no fewer than 30 samples, at least 3 classes, and at least 2 subclasses per class.
The implementation of the experiment consists of the following steps:
1. Obtaining the data set. In this experiment the pattern data set is generated at random. The training samples cover 6 classes with 8 patterns each; the test samples cover 6 classes with 2 patterns each.
2. Implementing the Least Mean Square Error (LMSE) algorithm:
(1) Input: the normalized augmented sample matrix $X$
(2) Compute the pseudo-inverse of $X$: $X^{\#} = (X^T X)^{-1} X^T$
(3) Set the initial values $c$ and $B(1)$
(4) Compute $e(k)$ and test for separability
(5) Compute $W(k+1)$ and $B(k+1)$
(The concrete update formulas are written out right after this list.)
3. Training the discriminant functions with the i_non_i scheme and evaluating the recognition accuracy.
4. Training the discriminant functions with the i_j scheme and evaluating the recognition accuracy.
5. Implementing the piecewise-linear discriminant method, training the subclass discriminant functions, and evaluating the recognition accuracy.
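Written out explicitly, the iteration used in the implementation below is (these formulas simply restate what the code does; the notation may differ slightly from the textbook):

$$X^{\#} = (X^T X)^{-1} X^T, \qquad W(1) = X^{\#} B(1), \qquad e(k) = X W(k) - B(k)$$
$$B(k+1) = B(k) + c\,\bigl(e(k) + |e(k)|\bigr), \qquad W(k+1) = W(k) + c\, X^{\#}\,|e(k)|$$

The iteration stops when $e(k) = 0$ (linearly separable, $W(k)$ is a solution) or when all components of $e(k)$ are non-positive, in which case $X W(k) > 0$ is checked directly to decide whether a solution exists; otherwise another iteration is performed.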
Operating system: Mac OS X
Development language: Python 2.7
As shown in the figure above, the pattern set is generated at random within the range $[0, 50]$ and divided into the 6 classes 0, 1, 2, 3, 4, 5.
The concrete implementation is as follows:
# encoding:utf-8
# Generate the data set
import numpy as np


# Generate the pattern set: sample_size patterns for each of the 6 classes
def gen_data_set(sample_size, class_size):
    x_k = []
    y_k = []
    for k in range(6):
        x = []
        y = []
        if k == 0:
            for i in range(sample_size):
                m = np.random.random() * 20
                n = np.random.random() * m
                x.append(m)
                y.append(n)
        if k == 1:
            for i in range(sample_size):
                m = np.random.random() * 30 + 20
                if m < 25:
                    n = np.random.random() * m
                else:
                    n = np.random.random() * ((-1) * m + 50)
                x.append(m)
                y.append(n)
        if k == 2:
            for i in range(sample_size):
                m = np.random.random() * 25 + 25
                n = np.random.random() * (2 * m - 50) + ((-1) * m + 50)
                x.append(m)
                y.append(n)
        if k == 3:
            for i in range(sample_size):
                m = np.random.random() * 30 + 20
                if m < 25:
                    n = np.random.random() * m + ((-1) * m + 50)
                else:
                    n = np.random.random() * (50 - m) + m
                x.append(m)
                y.append(n)
        if k == 4:
            for i in range(sample_size):
                m = np.random.random() * 20
                n = np.random.random() * m + ((-1) * m + 50)
                x.append(m)
                y.append(n)
        if k == 5:
            for i in range(sample_size):
                m = np.random.random() * 20
                n = np.random.random() * ((-2) * m + 50) + m
                x.append(m)
                y.append(n)
        # Record the coordinates
        # for i in range(sample_size):
        #     print str(k) + "," + str(x[i]) + "," + str(y[i])
        x_k.append(x)
        y_k.append(y)
    return x_k, y_k


# Generate the data sets and write them to files
def gen_data_to_file(train_file, test_file, class_size=6, train_set_size=10, test_set_size=5):
    # 8 training patterns and 2 test patterns per class, as used in the experiment
    train_sample_size = 8
    test_sample_size = 2
    train_x_k, train_y_k = gen_data_set(train_sample_size, class_size)
    test_x_k, test_y_k = gen_data_set(test_sample_size, class_size)
    # Write the data sets to files
    with open(train_file, 'w') as f:
        for i in range(class_size):
            for j in range(train_sample_size):
                f.write(str(i) + "," + str(train_x_k[i][j]) + "," + str(train_y_k[i][j]) + "\n")
    with open(test_file, 'w') as f:
        for i in range(class_size):
            for j in range(test_sample_size):
                f.write(str(i) + "," + str(test_x_k[i][j]) + "," + str(test_y_k[i][j]) + "\n")
The generated data sets are then written to the training and test files.
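A quick usage sketch (the file names here are placeholders, not necessarily those used in the project):

# Generate 6 classes of training and test patterns and write them to files.
gen_data_to_file("train.csv", "test.csv", class_size=6)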
# encoding:utf-8
import numpy as np
from numpy.linalg import inv


# Least Mean Square Error (LMSE) algorithm
# (1) Input: the normalized augmented sample matrix x
# (2) Compute the pseudo-inverse of X: x# = (xT x)-1 xT
# (3) Set the initial values c and B(1)
# (4) Compute e(k) and test for separability
# (5) Compute W(k+1) and B(k+1)
def least_mean_square_error(data, b_1=0.1, c=1, k_n=100000):
    # x: the normalized augmented sample matrix
    x = np.array(data, dtype=np.float64)
    x_sharp = inv(x.T.dot(x)).dot(x.T)
    row, col = x.shape
    # print x
    k = 1
    while k < k_n:
        if k % 10000 == 0:
            print "Iteration " + str(k) + "..."
        if k == 1:
            # Set the initial value B(1)
            b_k = (np.zeros(row) + b_1).T
            # b_k = x[:, 0]
            w_k = x_sharp.dot(b_k)
        else:
            b_k = b(k + 1, b_k, c, e_k)
            w_k = w(k + 1, w_k, c, x_sharp, e_k)
            # w_k = w(k + 1, x_sharp, b_k)  # method 2
        e_k = e(k + 1, x, w_k, b_k)
        if (e_k == 0).all():
            print "Iteration " + str(k) + "... " + "linearly separable; the solution is"
            print w_k
            return w_k
        elif (e_k >= 0).all():
            k = k + 1
            print "Iteration " + str(k) + "... " + "linearly separable; continue iterating for the optimal solution"
        elif (e_k <= 0).all():
            # print e_k
            print "Iteration " + str(k) + "... " + "e_k < 0, stop iterating and check XW(k)"
            if (x.dot(w_k) > 0).all():
                print "Linearly separable; the solution is " + str(w_k)
                return w_k
            else:
                print "XW(k) > 0 does not hold, so there is no solution; algorithm terminates"
                return None
        else:
            k = k + 1
    return None


# B(k+1) = B(k) + c * (e(k) + |e(k)|)
def b(k, b_k, c, e_k):
    return b_k + c * (e_k + np.abs(e_k))


# e(k) = X W(k) - B(k)
def e(k, x, w_k, b_k):
    return x.dot(w_k) - b_k


# Method 1: W(k+1) = W(k) + c * X# |e(k)|
def w(k, w_k, c, x_sharp, e_k):
    return w_k + c * x_sharp.dot(np.abs(e_k))


# Method 2
# def w(k, x_sharp, b_k):
#     return x_sharp.dot(b_k)
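As a small sanity check (the sample matrix below is made up and linearly separable; it is not part of the experiment), the solver can be called directly on a normalized augmented sample matrix:

# Usage sketch: two class-1 samples and two negated class-2 samples.
data = [[1.0, 2.0, 1.0], [2.0, 1.0, 1.0],
        [-5.0, -6.0, -1.0], [-6.0, -5.0, -1.0]]
w = least_mean_square_error(data, b_1=0.1, c=1)
if w is not None:
    print w  # augmented weight vector satisfying XW > 0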
i_j scheme

# encoding:utf-8
# w_i / w_j dichotomy
import numpy as np
from algorithm.LMSE import least_mean_square_error


# Train the discriminant functions d_k_k that separate every pair of classes i / j
def train_to_file(data_k, to_file):
    d_k_k = []
    i = 0
    j = 0
    # data_i: the list of augmented samples of the i-th pattern class
    for data_i in data_k:
        d_i_k = []
        for data_j in data_k:
            d_i_j = []
            if data_i != data_j:
                print "----------------d_" + str(i) + "_" + str(j) + "----------------"
                data_i_j = data_i[:]
                for point in data_j:
                    # Normalize the augmented samples (negate those of class j)
                    data_i_j.append([x * (-1) for x in point])
                d_i_j = least_mean_square_error(data_i_j)
                if d_i_j is None:
                    # No discriminant function could be found
                    print
                    print "No discriminant function exists for data_i_j: "
                    for data in data_i_j:
                        print data
                    return None
            d_i_k.append(d_i_j)
            j = j + 1
            print
        d_k_k.append(d_i_k)
        i = i + 1
        j = 0
    # Write the discriminant functions to a file
    i = 0
    j = 0
    with open(to_file, 'w') as f:
        for d_i_k in d_k_k:
            for d_i_j in d_i_k:
                f.write(str(i) + "," + str(j))
                for d in d_i_j:
                    f.write("," + str(d))
                f.write("\n")
                j = j + 1
            i = i + 1
            j = 0
    return d_k_k


# Use the pairwise discriminant functions d_k_k to recognize the pattern x (x is an augmented vector)
def recognize(x, d_k_k, index):
    tag = 0
    for d_i_k in d_k_k:
        d_i_mat = np.array(d_i_k)
        if (d_i_mat.dot(x) > 0).all():
            output_str = ""
            if tag == index:
                output_str = "classified correctly"
            else:
                output_str = "misclassified"
            print str(x) + "\tbelongs to class " + str(tag) + ", " + output_str
            return tag
        else:
            tag = tag + 1
    print str(x) + "\tfalls in the IR region"
    return -1


def load_d_k_k(from_file):
    class_tag = []
    d_k_k = []
    d_i_k = []
    with open(from_file, 'r') as f:
        for line in f.readlines():
            d_i_j_str = line.strip().split(',')
            if d_i_j_str[0] not in class_tag:  # a new class label appears
                class_tag.append(d_i_j_str[0])
                if len(class_tag) != 1:
                    d_k_k.append(d_i_k)
                    d_i_k = []
            if len(d_i_j_str) != 2:
                d_i_k.append(np.array([float(x) for x in d_i_j_str[2:]]))
        d_k_k.append(d_i_k)
    return d_k_k


def i_j_main(train_data_k, test_data_k, i_j_file):
    train_to_file(train_data_k, i_j_file)
    return load_d_k_k(i_j_file)
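The post later reports a recognition accuracy on the test set; that computation is not shown in the listings, but it can be done with a small helper along these lines (a sketch; test_data_k is assumed to be a list of 6 lists of augmented test patterns [x, y, 1]):

# Evaluate recognition accuracy of the i_j scheme on the test set.
def evaluate(test_data_k, d_k_k):
    correct = 0
    total = 0
    for index, patterns in enumerate(test_data_k):
        for x in patterns:
            if recognize(x, d_k_k, index) == index:
                correct += 1
            total += 1
    print "Recognition accuracy: " + str(float(correct) / total)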
i_non_i scheme

# encoding:utf-8
# w_i / non_w_i dichotomy
from algorithm.LMSE import least_mean_square_error


# Train one discriminant function d_i for each pattern class
def train_to_file(data_k, to_file):
    d_k = []
    index = 0
    for data_i in data_k:
        print "--------------------d_" + str(index) + "--------------------"
        # data_i: the list of augmented samples of the i-th pattern class
        data = data_i[:]
        for data_other in data_k:
            if data_i != data_other:
                for point in data_other:
                    # Normalize the augmented samples (negate those of all other classes)
                    data.append([x * (-1) for x in point])
        d_i = least_mean_square_error(data, c=0.28)
        if d_i is not None:
            d_k.append(d_i)
        else:
            # No discriminant function could be found
            print
            print "No discriminant function exists for data_" + str(index) + ": " + ".... " + str(len(data))
            for data_element in data:
                print data_element
        index = index + 1
        print
    return d_k


# Use the k per-class discriminant functions d_k to recognize the pattern x
def recognize(x, d_k):
    positive_tag = 0
    negative_tag = 0
    index = 0
    tag = 0
    for d_i in d_k:
        if d_i.dot(x) > 0:
            tag = index
            positive_tag = positive_tag + 1
            if positive_tag > 1:
                print "More than one d_i(X) > 0 condition holds; classification fails"
                return None
        else:
            negative_tag = negative_tag + 1
        index = index + 1
    if positive_tag == 0:
        print "No d_i(X) > 0 condition holds; classification fails"
        return None
    print "X belongs to class " + str(tag) + ":"
    print d_k[tag]
    return d_k[tag]


def i_non_i_main(train_data_k, test_data_k, i_j_file):
    train_to_file(train_data_k, i_j_file)
For the piecewise-linear discriminant method we assume there are parent classes 0-2 and that every parent class has 2 subclasses. The concrete implementation is as follows:
# encoding:utf-8
# Piecewise-linear discriminant functions: the subclass partition is known in advance
import numpy as np
from algorithm.LMSE import least_mean_square_error


# Train one discriminant function per parent class, separating its two subclasses
def train_to_file(data_k, to_file):
    d_k = []
    index = 0
    for i in range(0, 6, 2):
        print "------------parent_class_" + str(index) + "------------"
        data_division_1 = data_k[i][:]
        data_division_2 = data_k[i + 1][:]
        data_i = data_division_1
        for point in data_division_2:
            data_i.append([x * (-1) for x in point])
        d_i = least_mean_square_error(data_i, b_1=1, c=0.8)
        if d_i is None:
            # No discriminant function could be found
            print
            print "No discriminant function exists for data_i: "
            for data in data_i:
                print data
            # return None
        d_k.append(d_i)
        index += 1
        print
    # Write the discriminant functions to a file
    index = 0
    with open(to_file, 'w') as f:
        for d_i in d_k:
            f.write(str(index))
            for d in d_i:
                f.write("," + str(d))
            f.write("\n")
            index += 1
    return d_k


def recognize(x, d_k, index):
    d_0_division_1 = d_k[0].dot(x)
    d_0_division_2 = ((-1) * d_k[0]).dot(x)
    d_1_division_1 = d_k[1].dot(x)
    d_1_division_2 = ((-1) * d_k[1]).dot(x)
    d_2_division_1 = d_k[2].dot(x)
    d_2_division_2 = ((-1) * d_k[2]).dot(x)
    # First let the stronger subclass response stand in for its parent class
    d_0 = d_0_division_1 if d_0_division_1 > d_0_division_2 else d_0_division_2
    d_1 = d_1_division_1 if d_1_division_1 > d_1_division_2 else d_1_division_2
    d_2 = d_2_division_1 if d_2_division_1 > d_2_division_2 else d_2_division_2
    # Then decide between the parent classes
    tag = -1
    if d_0 > d_1 and d_0 > d_2:
        tag = 0
    if d_1 > d_0 and d_1 > d_2:
        tag = 1
    if d_2 > d_0 and d_2 > d_1:
        tag = 2
    if tag == index:
        output_str = "classified correctly"
    elif tag == -1:
        output_str = "IR region"
    else:
        output_str = "misclassified"
    print str(x) + "\tbelongs to class " + str(tag) + ", " + output_str
    return tag


def load_d_k(from_file):
    class_tag = []
    d_k = []
    with open(from_file, 'r') as f:
        for line in f.readlines():
            d_i_str = line.strip().split(',')
            d_k.append(np.array([float(x) for x in d_i_str[1:]]))
    return d_k


def class_division_main(train_data_k, test_data_k, class_division_file):
    train_to_file(train_data_k, class_division_file)
    return load_d_k(class_division_file)
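A minimal usage sketch (the file name is a placeholder, and train_data_k / test_data_k are assumed to hold the 6 generated classes in augmented [x, y, 1] form, so classes 2p and 2p+1 form parent class p):

# Train the parent-class discriminant functions and run the test patterns through them.
d_k = class_division_main(train_data_k, test_data_k, "class_division.csv")
for parent in range(3):
    for x in test_data_k[2 * parent] + test_data_k[2 * parent + 1]:
        recognize(x, d_k, parent)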
The training set used is:
0,18.5029168459,16.8228372586
0,15.1155598703,7.63485856842
0,6.95647929713,5.11539442509
0,0.597347656694,0.288674350297
0,6.36715034554,5.60396174026
0,15.7424995772,8.58485907293
0,18.0807236204,11.3012359879
0,10.1722692483,4.65541515255
1,47.8757125439,1.16332889424
1,26.0643690758,5.61636639878
1,31.1129222794,17.8499786928
1,23.5410782063,19.9956964598
1,26.6069909527,10.9217740638
1,39.7166244327,10.2806301464
1,49.5181370187,0.375776889409
1,49.2415167242,0.707007895968
2,37.8611287104,31.5632667325
2,27.6685288498,23.2728667808
2,47.4112562524,29.4781641683
2,37.9726890343,36.0084457108
2,37.0258232578,24.4223677169
2,44.5991664403,20.8661377571
2,49.2948779017,39.4635364912
2,32.0020253605,19.3541047235
3,44.2110390945,44.4664083051
3,38.3651365989,42.9632379304
3,35.960724763,44.7193712254
3,26.4532769283,34.8314226887
3,48.9018927779,49.5881169749
3,26.6484189936,31.2731506208
3,41.9097085676,47.9449438728
3,35.331013959,48.0256107217
4,3.81087608968,47.13431803
4,15.8810932052,35.593631285
4,5.46591652097,48.2716875232
4,18.7156249506,42.1656191735
4,13.4111687002,39.336560121
4,8.82628702024,49.9086481899
4,17.6593190139,48.6801651249
4,17.4355973285,45.2334103922
5,2.32851634741,7.54190740531
5,0.192342631621,17.2004066838
5,8.26038228069,29.4464752936
5,15.4799873459,34.2505881121
5,13.3166961224,32.4399699833
5,16.5491387156,18.6037637969
5,10.6585086421,21.9943090312
5,12.1629069128,34.3220738065
The test set is:
0,12.3228402554,2.9405381601
0,14.1333891941,0.65068883671
1,47.7641004258,0.797181091341
1,44.9040483706,3.78286547533
2,33.6360204882,20.0574966127
2,43.5656985063,25.4610595264
3,35.3512989762,40.962956587
3,42.7841966519,43.4279532924
4,19.2747318386,31.3287647744
4,7.48232294478,47.1481142472
5,16.0034786096,31.2251124134
5,0.346491676919,25.5140238588
Each pattern in the data set is two-dimensional: the first field is the class the pattern belongs to, and the remaining two are the pattern's coordinates.
i_j scheme
Using the LMSE algorithm, the training results of the discriminant functions are as follows:
----------------d_0_1----------------
Iteration 3088... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.03574661 -0.00627283 0.86694337]
----------------d_0_2----------------
Iteration 1484... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.01770862 -0.00584334 0.52596256]
----------------d_0_3----------------
Iteration 1177... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.01803347 -0.02400582 0.17017424]
----------------d_0_4----------------
Iteration 650... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.00532768 -0.0099107 0.16814849]
----------------d_0_5----------------
Iteration 5069... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.05118984 -0.05614292 0.09732183]
----------------d_1_0----------------
Iteration 3244... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.03574661 0.00627283 -0.86694337]
----------------d_1_2----------------
Iteration 10000...
Iteration 20000...
Iteration 22352... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.06312515 -0.09565371 3.77142462]
----------------d_1_3----------------
Iteration 1286... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.00263534 -0.01700837 0.50213301]
----------------d_1_4----------------
Iteration 1823... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.00410129 -0.01483631 0.49321121]
----------------d_1_5----------------
Iteration 2199... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.02822048 0.00192831 -0.60289858]
----------------d_2_0----------------
Iteration 1552... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.01770862 0.00584334 -0.52596256]
----------------d_2_1----------------
Iteration 10000...
Iteration 20000...
Iteration 20438... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.06312515 0.09565371 -3.77142462]
----------------d_2_3----------------
Iteration 10000...
Iteration 13260... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.03983482 -0.05302738 0.4967981 ]
----------------d_2_4----------------
Iteration 860... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.00924338 -0.0073895 0.01622401]
----------------d_2_5----------------
Iteration 1067... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.01752869 0.00109051 -0.41037223]
----------------d_3_0----------------
Iteration 2266... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.01803347 0.02400582 -0.17017424]
----------------d_3_1----------------
Iteration 1289... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.00263534 0.01700837 -0.50213301]
----------------d_3_2----------------
Iteration 3888... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.03983482 0.05302738 -0.4967981 ]
----------------d_3_4----------------
Iteration 848... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.01818381 -0.0101775 -0.02652464]
----------------d_3_5----------------
Iteration 887... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.01823988 0.00124634 -0.42504089]
----------------d_4_0----------------
Iteration 439... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.00532768 0.0099107 -0.16814849]
----------------d_4_1----------------
Iteration 1578... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.00410129 0.01483631 -0.49321121]
----------------d_4_2----------------
Iteration 549... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.00924338 0.0073895 -0.01622401]
----------------d_4_3----------------
Iteration 1420... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.01818381 0.0101775 0.02652464]
----------------d_4_5----------------
Iteration 10000...
Iteration 20000...
Iteration 30000...
Iteration 37054... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.06526169 0.12942484 -5.54312717]
----------------d_5_0----------------
Iteration 5050... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.05118984 0.05614292 -0.09732183]
----------------d_5_1----------------
Iteration 2199... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.02822048 -0.00192831 0.60289858]
----------------d_5_2----------------
Iteration 1303... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.01752869 -0.00109051 0.41037223]
----------------d_5_3----------------
Iteration 1005... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.01823988 -0.00124634 0.42504089]
----------------d_5_4----------------
Iteration 10000...
Iteration 20000...
Iteration 30000...
Iteration 37479... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.06526169 -0.12942484 5.54312717]
[12.3228402554, 2.9405381601, 1] belongs to class 0, classified correctly
[14.1333891941, 0.65068883671, 1] belongs to class 0, classified correctly
[47.7641004258, 0.797181091341, 1] belongs to class 1, classified correctly
[44.9040483706, 3.78286547533, 1] belongs to class 1, classified correctly
[33.6360204882, 20.0574966127, 1] belongs to class 2, classified correctly
[43.5656985063, 25.4610595264, 1] belongs to class 2, classified correctly
[35.3512989762, 40.962956587, 1] belongs to class 3, classified correctly
[42.7841966519, 43.4279532924, 1] belongs to class 3, classified correctly
[19.2747318386, 31.3287647744, 1] falls in the IR region
[7.48232294478, 47.1481142472, 1] belongs to class 4, classified correctly
[16.0034786096, 31.2251124134, 1] belongs to class 5, classified correctly
[0.346491676919, 25.5140238588, 1] belongs to class 5, classified correctly
Recognition accuracy: 0.916666666667
The training data, test data, and recognition results are shown in the figure below:
i_non_i scheme
--------------------d_0--------------------
Iteration 1748... e_k < 0, stop iterating and check XW(k)
XW(k) > 0 does not hold, so there is no solution; algorithm terminates
No discriminant function exists for data_0: .... 48
[18.5029168459, 16.8228372586, 1]
[15.1155598703, 7.63485856842, 1]
[6.95647929713, 5.11539442509, 1]
[0.597347656694, 0.288674350297, 1]
[6.36715034554, 5.60396174026, 1]
[15.7424995772, 8.58485907293, 1]
[18.0807236204, 11.3012359879, 1]
[10.1722692483, 4.65541515255, 1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]
--------------------d_1--------------------
Iteration 3180... e_k < 0, stop iterating and check XW(k)
XW(k) > 0 does not hold, so there is no solution; algorithm terminates
No discriminant function exists for data_1: .... 48
[47.8757125439, 1.16332889424, 1]
[26.0643690758, 5.61636639878, 1]
[31.1129222794, 17.8499786928, 1]
[23.5410782063, 19.9956964598, 1]
[26.6069909527, 10.9217740638, 1]
[39.7166244327, 10.2806301464, 1]
[49.5181370187, 0.375776889409, 1]
[49.2415167242, 0.707007895968, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]
--------------------d_2--------------------
Iteration 148... e_k < 0, stop iterating and check XW(k)
XW(k) > 0 does not hold, so there is no solution; algorithm terminates
No discriminant function exists for data_2: .... 48
[37.8611287104, 31.5632667325, 1]
[27.6685288498, 23.2728667808, 1]
[47.4112562524, 29.4781641683, 1]
[37.9726890343, 36.0084457108, 1]
[37.0258232578, 24.4223677169, 1]
[44.5991664403, 20.8661377571, 1]
[49.2948779017, 39.4635364912, 1]
[32.0020253605, 19.3541047235, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]
--------------------d_3--------------------
Iteration 1207... e_k < 0, stop iterating and check XW(k)
XW(k) > 0 does not hold, so there is no solution; algorithm terminates
No discriminant function exists for data_3: .... 48
[44.2110390945, 44.4664083051, 1]
[38.3651365989, 42.9632379304, 1]
[35.960724763, 44.7193712254, 1]
[26.4532769283, 34.8314226887, 1]
[48.9018927779, 49.5881169749, 1]
[26.6484189936, 31.2731506208, 1]
[41.9097085676, 47.9449438728, 1]
[35.331013959, 48.0256107217, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]
--------------------d_4--------------------
Iteration 10000...
Iteration 20000...
Iteration 30000...
Iteration 40000...
Iteration 50000...
Iteration 52655... e_k < 0, stop iterating and check XW(k)
XW(k) > 0 does not hold, so there is no solution; algorithm terminates
No discriminant function exists for data_4: .... 48
[3.81087608968, 47.13431803, 1]
[15.8810932052, 35.593631285, 1]
[5.46591652097, 48.2716875232, 1]
[18.7156249506, 42.1656191735, 1]
[13.4111687002, 39.336560121, 1]
[8.82628702024, 49.9086481899, 1]
[17.6593190139, 48.6801651249, 1]
[17.4355973285, 45.2334103922, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-2.32851634741, -7.54190740531, -1]
[-0.192342631621, -17.2004066838, -1]
[-8.26038228069, -29.4464752936, -1]
[-15.4799873459, -34.2505881121, -1]
[-13.3166961224, -32.4399699833, -1]
[-16.5491387156, -18.6037637969, -1]
[-10.6585086421, -21.9943090312, -1]
[-12.1629069128, -34.3220738065, -1]
--------------------d_5--------------------
Iteration 534... e_k < 0, stop iterating and check XW(k)
XW(k) > 0 does not hold, so there is no solution; algorithm terminates
No discriminant function exists for data_5: .... 48
[2.32851634741, 7.54190740531, 1]
[0.192342631621, 17.2004066838, 1]
[8.26038228069, 29.4464752936, 1]
[15.4799873459, 34.2505881121, 1]
[13.3166961224, 32.4399699833, 1]
[16.5491387156, 18.6037637969, 1]
[10.6585086421, 21.9943090312, 1]
[12.1629069128, 34.3220738065, 1]
[-18.5029168459, -16.8228372586, -1]
[-15.1155598703, -7.63485856842, -1]
[-6.95647929713, -5.11539442509, -1]
[-0.597347656694, -0.288674350297, -1]
[-6.36715034554, -5.60396174026, -1]
[-15.7424995772, -8.58485907293, -1]
[-18.0807236204, -11.3012359879, -1]
[-10.1722692483, -4.65541515255, -1]
[-47.8757125439, -1.16332889424, -1]
[-26.0643690758, -5.61636639878, -1]
[-31.1129222794, -17.8499786928, -1]
[-23.5410782063, -19.9956964598, -1]
[-26.6069909527, -10.9217740638, -1]
[-39.7166244327, -10.2806301464, -1]
[-49.5181370187, -0.375776889409, -1]
[-49.2415167242, -0.707007895968, -1]
[-37.8611287104, -31.5632667325, -1]
[-27.6685288498, -23.2728667808, -1]
[-47.4112562524, -29.4781641683, -1]
[-37.9726890343, -36.0084457108, -1]
[-37.0258232578, -24.4223677169, -1]
[-44.5991664403, -20.8661377571, -1]
[-49.2948779017, -39.4635364912, -1]
[-32.0020253605, -19.3541047235, -1]
[-44.2110390945, -44.4664083051, -1]
[-38.3651365989, -42.9632379304, -1]
[-35.960724763, -44.7193712254, -1]
[-26.4532769283, -34.8314226887, -1]
[-48.9018927779, -49.5881169749, -1]
[-26.6484189936, -31.2731506208, -1]
[-41.9097085676, -47.9449438728, -1]
[-35.331013959, -48.0256107217, -1]
[-3.81087608968, -47.13431803, -1]
[-15.8810932052, -35.593631285, -1]
[-5.46591652097, -48.2716875232, -1]
[-18.7156249506, -42.1656191735, -1]
[-13.4111687002, -39.336560121, -1]
[-8.82628702024, -49.9086481899, -1]
[-17.6593190139, -48.6801651249, -1]
[-17.4355973285, -45.2334103922, -1]
Thus, with the i_non_i scheme, no discriminant functions could be trained for this sample data.
Using the LMSE algorithm under the i_non_i scheme, the discriminant functions for the sample data could not be trained even after changing the algorithm parameters B(1) and c several times.
With the subclass partition known in advance, the output of the piecewise-linear discriminant method is as follows:
------------parent_class_0------------
Iteration 3846... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [-0.35746611 -0.0627283 8.66943368]
------------parent_class_1------------
Iteration 3948... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.39834821 -0.53027375 4.96798099]
------------parent_class_2------------
Iteration 10000...
Iteration 20000...
Iteration 30000...
Iteration 40000...
Iteration 46418... e_k < 0, stop iterating and check XW(k)
Linearly separable; the solution is [ 0.6524255 1.29430561 -55.43026674]
[12.3228402554, 2.9405381601, 1] belongs to class 2, misclassified
[14.1333891941, 0.65068883671, 1] belongs to class 2, misclassified
[47.7641004258, 0.797181091341, 1] belongs to class 1, misclassified
[44.9040483706, 3.78286547533, 1] belongs to class 2, misclassified
[33.6360204882, 20.0574966127, 1] belongs to class 1, classified correctly
[43.5656985063, 25.4610595264, 1] belongs to class 1, classified correctly
[35.3512989762, 40.962956587, 1] belongs to class 2, misclassified
[42.7841966519, 43.4279532924, 1] belongs to class 2, misclassified
[19.2747318386, 31.3287647744, 1] belongs to class 1, misclassified
[7.48232294478, 47.1481142472, 1] belongs to class 1, misclassified
[16.0034786096, 31.2251124134, 1] belongs to class 1, misclassified
[0.346491676919, 25.5140238588, 1] belongs to class 2, classified correctly
Recognition accuracy: 0.25
The experimental results are shown in the figure below:
In the four subplots above, train_data shows the training data of the three parent classes, class_division is the known subclass partition, discriminant_func shows the trained discriminant functions (3 parent classes, 2 subclasses each), and test_data shows the recognition results of the discriminant functions on the test set.
In this post we studied the core ideas of discriminant functions and geometric classification, and implemented the whole multi-class pattern recognition pipeline based on the LMSE algorithm. The key points are:
Looking back at the practice above, a few things are worth further thought:
Under the i_non_i scheme no discriminant functions could be trained, which reflects a characteristic difference between the i_non_i and i_j schemes: although the i_j dichotomy requires training more discriminant functions than the i_non_i dichotomy, the patterns are more likely to be linearly separable under it, because each pairwise function does not have to account for the influence of the remaining pattern classes and is therefore subject to fewer constraints. This is the main advantage of the i_j dichotomy.