习题 3.5:
编程实现线性判别分析,并给出西瓜数据集 3.0α 上的结果.
西瓜数据集3.0α:
sn | density | sugar_ratio | good_melon |
1 | 0.697 | 0.46 | 1 |
2 | 0.774 | 0.376 | 1 |
3 | 0.634 | 0.264 | 1 |
4 | 0.608 | 0.318 | 1 |
5 | 0.556 | 0.215 | 1 |
6 | 0.403 | 0.237 | 1 |
7 | 0.481 | 0.149 | 1 |
8 | 0.437 | 0.211 | 1 |
9 | 0.666 | 0.091 | 0 |
10 | 0.243 | 0.267 | 0 |
11 | 0.245 | 0.057 | 0 |
12 | 0.343 | 0.099 | 0 |
13 | 0.639 | 0.161 | 0 |
14 | 0.657 | 0.198 | 0 |
15 | 0.36 | 0.37 | 0 |
16 | 0.593 | 0.042 | 0 |
17 | 0.719 | 0.103 | 0 |
Python 程序:
#encoding:UTF-8
import csv
import numpy as np
from math import *
from numpy.linalg import *
from matplotlib import pyplot as plt
# Load the watermelon 3.0alpha data set.
# Each data row of the CSV is [sn, density, sugar_ratio, label]; rows whose
# first field is not a plain integer (e.g. the header line) are skipped.
_features = []
_labels = []
# newline='' is what the csv module requires for files passed to csv.reader;
# the `with` block guarantees the file handle is closed after reading.
with open('../data_set/watermelon_data_set_30a.csv', 'r', newline='') as _wm_file:
    wm_data_30a = csv.reader(_wm_file)
    for stu in wm_data_30a:
        # csv.reader yields [] for blank lines, so check before indexing stu[0].
        if stu and stu[0].isdigit():
            _features.append([float(stu[1]), float(stu[2])])
            _labels.append([float(stu[3])])
if not _features:
    raise ValueError('no data rows found in watermelon_data_set_30a.csv')
# Size the matrices from the rows actually read instead of assuming 17.
sn = len(_features)
xi_d = np.asmatrix(_features)  # one row per sample: [density, sugar_ratio]
yi_d = np.asmatrix(_labels)    # column vector of labels (1 = good melon)
def _mean_and_scatter(samples):
    """Return the mean vector and scatter matrix of one class.

    samples: 2-D array-like with one row per sample, shape (m, d).
    Returns (mean, scatter) as np.matrix objects of shape (d, 1) and (d, d),
    where scatter = sum_i (x_i - mean)(x_i - mean)^T (not divided by m).
    Raises ValueError if samples contains no rows.
    """
    samples = np.asarray(samples, dtype=float)
    if samples.shape[0] == 0:
        raise ValueError('cannot compute statistics of an empty class')
    mean = samples.mean(axis=0)
    centered = samples - mean
    return np.asmatrix(mean).T, np.asmatrix(centered.T @ centered)


# Split the samples by their label rather than assuming that the first 8 rows
# are the good melons (y=1) and the remaining 9 rows are the bad ones (y=0).
_labels_flat = np.asarray(yi_d).ravel()
_samples = np.asarray(xi_d)
# cov_1 / cov_0 hold per-class scatter matrices (sums of outer products).
u1, cov_1 = _mean_and_scatter(_samples[_labels_flat == 1])  # y=1
print('u1=',u1)
print('cov_1=',cov_1)
u0, cov_0 = _mean_and_scatter(_samples[_labels_flat == 0])  # y=0
print('u0=',u0)
print('cov_0=',cov_0)
# Within-class scatter matrix S_w = S_0 + S_1.
sw= cov_0+cov_1
print('sw=',sw)
print('sw_inv=',sw.I)
_mean_diff = u0-u1
print('u0-u1=',_mean_diff)
# LDA projection direction w = S_w^{-1} (u0 - u1). Solving the linear system
# avoids multiplying by an explicitly formed inverse (better conditioned) and
# raises numpy.linalg.LinAlgError if S_w is singular.
w = np.linalg.solve(sw, _mean_diff)
print('final w=',w)
# Scatter plot of the samples: red '+' for good melons (label 1), blue 'o'
# for every other label; masks replace the hard-coded 17-row loop.
_is_good = np.asarray(yi_d).ravel() == 1
_points = np.asarray(xi_d)
plt.plot(_points[_is_good, 0], _points[_is_good, 1], '+r')
plt.plot(_points[~_is_good, 0], _points[~_is_good, 1], 'ob')
# LDA decision boundary: a sample x is assigned by comparing w^T x with the
# projection of the midpoint of the two class means, so the boundary is the
# line w^T x = w^T (u0 + u1) / 2 (threshold derived from the data rather than
# a hard-coded constant).
_w = np.asarray(w).ravel()
threshold = _w @ np.asarray(u0 + u1).ravel() / 2
px = [0.1, 0.9]
if _w[1] != 0:
    ply = (threshold - _w[0] * px[0]) / _w[1]
    pry = (threshold - _w[0] * px[1]) / _w[1]
    py = [ply, pry]
else:
    # w has no y-component: the boundary is the vertical line x = threshold / w0.
    px = [threshold / _w[0]] * 2
    py = list(plt.ylim())
plt.plot(px,py)
plt.xlabel('density')
plt.ylabel('suger ratio')
plt.title('linear discriminant analysis')
plt.show()
运行结果:
u1= [[0.57375]
[0.27875]]
cov_1= [[0.1168675 0.0718285]
[0.0718285 0.0712995]]
u0= [[0.49611111]
[0.15422222]]
cov_0= [[ 0.30332289 -0.05006522]
[-0.05006522 0.09295756]]
sw= [[0.42019039 0.02176328]
[0.02176328 0.16425706]]
sw_inv= [[ 2.39631815 -0.31750075]
[-0.31750075 6.13008588]]
u0-u1= [[-0.07763889]
[-0.12452778]]
final w= [[-0.14650982]
[-0.73871557]]