西瓜书（周志华《机器学习》）P60 线性判别分析（LDA）代码实现：
import numpy as np
import matplotlib.pyplot as plt
def load_data(file_name, n_features=2, delimiter='\t'):
    '''
    Load samples from a delimited text file.

    Each non-blank line holds `n_features` float feature columns, optionally
    followed by an integer class label in column `n_features`; any further
    columns are ignored.

    :param file_name: (string) path of the training data file
    :param n_features: (int) number of leading feature columns per line
                       (default 2)
    :param delimiter: (string) column separator (default: tab)
    :return: feature_mat (ndarray, float, shape (num, n_features)) features;
             label_mat (ndarray, int, shape (num_labelled, 1)) labels.
             A line without a label column still contributes its features
             but no label, so label_mat only lines up row-for-row with
             feature_mat when every line is labelled (or none is).
    :raises IndexError: if a line has fewer than `n_features` columns
    :raises ValueError: if a feature/label field is not numeric
    '''
    feature_data = []
    lable_data = []
    # `with` guarantees the file is closed even if a malformed line raises.
    with open(file_name) as fr:
        for line in fr:
            # Skip blank / whitespace-only lines (e.g. a trailing empty line),
            # which would otherwise fail the float() conversion.
            if not line.strip():
                continue
            lineArr = line.split(delimiter)
            # float()/int() tolerate surrounding whitespace such as '\n'.
            feature_data.append([float(lineArr[i]) for i in range(n_features)])
            if len(lineArr) > n_features:
                lable_data.append([int(lineArr[n_features])])
    feature_mat = np.array(feature_data, dtype=float)
    label_mat = np.array(lable_data, dtype=int)
    return feature_mat, label_mat
def LDA(x1,x2):
    '''
    Two-class Fisher linear discriminant analysis.

    Computes the projection direction w = Sw^-1 (u1 - u2), where u1/u2 are
    the class means and Sw is the within-class scatter matrix.

    :param x1: class-1 samples, array-like of shape (num1, d)
    :param x2: class-2 samples, array-like of shape (num2, d)
    :return: projection vector w as a numpy matrix of shape (1, d)
    '''
    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    u1 = np.mean(x1, axis=0)
    u2 = np.mean(x2, axis=0)
    # Within-class scatter: sum of each class's centred scatter matrix (d, d).
    d1 = x1 - u1
    d2 = x2 - u2
    Sw = np.dot(d1.T, d1) + np.dot(d2.T, d2)
    # The SVD-based pseudo-inverse is numerically stable and still defined
    # when Sw is singular; it equals Sw^-1 whenever Sw is invertible.
    # (The former np.mat(...).I no longer exists: np.mat was removed in NumPy 2.0.)
    w = np.dot(np.linalg.pinv(Sw), u1 - u2)
    # Keep the original (1, d) matrix return so callers indexing w[0, i] still work.
    return np.asmatrix(w.reshape(1, -1))
def LDA_2(x1,x2):
    '''
    Loop-based variant of two-class Fisher LDA.

    Builds the within-class scatter matrix Sw by summing per-sample outer
    products, then returns w = Sw^-1 (u1 - u2).

    :param x1: class-1 samples, array-like of shape (num1, d)
    :param x2: class-2 samples, array-like of shape (num2, d)
    :return: projection vector w as a numpy matrix of shape (1, d),
             the same shape LDA returns
    '''
    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    u1 = np.mean(x1, axis=0)
    u2 = np.mean(x2, axis=0)
    dim = u1.size
    s1 = np.zeros((dim, dim))
    s2 = np.zeros((dim, dim))
    # Each sample contributes (x - u)(x - u)^T. Rows are 1-D, so the previous
    # np.dot(v.T, v) produced the scalar v.v instead; np.outer is required.
    for row in x1:
        s1 += np.outer(row - u1, row - u1)
    for row in x2:
        s2 += np.outer(row - u2, row - u2)
    Sw = s1 + s2
    # Pseudo-inverse: stable, defined for singular Sw, equals Sw^-1 otherwise
    # (np.mat was removed in NumPy 2.0).
    w = np.dot(np.linalg.pinv(Sw), u1 - u2)
    return np.asmatrix(w.reshape(1, -1))
if __name__ == "__main__":
# 1. 导入数据
print("------1. load data------")
feature_data, lable_data = load_data("train_data.txt")
x1_data =[]
x2_data = []
for i in range(0,len(feature_data)):
if(lable_data[i] == 1):
x1_data.append(feature_data[i])
else:
x2_data.append(feature_data[i])
x1_data = np.array(x1_data)
x2_data = np.array(x2_data)
w = LDA(x1_data,x2_data)
# w = LDA_2(x1_data, x2_data)
print(w)
"""使用scatter()绘制散点图"""
x_values = range(-5, 5)
#映射适量的xielv
rate = w[0,1]/w[0,0]
#对应垂直于这个直线的斜率
rate2 = - w[0,0]/w[0,1]
y_values = [rate * x for x in x_values]
'''
scatter()
x:横坐标 y:纵坐标 s:点的尺寸
'''
#y = kx+b 中的b
b = feature_data[:, 1] - feature_data[:, 0]*rate2
# b = feature_data[:, 0] - feature_data[:, 1]*rate2
#计算出焦点x
x_ = b/(rate-rate2)
#计算出焦点y
y_ = x_*rate2+b
plt.plot(x_values, y_values)
plt.scatter(feature_data[:,0],feature_data[:,1] , s=10)
w = np.array(w)
# aa =np.squeeze()
plt.scatter(x_,y_ , s=10)
# plt.scatter(x_values, y_values, c=y_values, cmap=plt.cm.Blues, edgecolors='none', s=10)
# 设置图表标题并给坐标轴加上标签
plt.title('LDA', fontsize=24)
plt.xlabel('feature_1', fontsize=14)
plt.ylabel('feature_2', fontsize=14)
# 设置刻度标记的大小
plt.tick_params(axis='both', which='major', labelsize=14)
plt.show()
最终效果：运行后会绘制原始样本点、LDA 投影直线以及各样本在该直线上的投影点（效果图未包含在本文中）。
数据集（即代码中读取的 train_data.txt，每行为制表符分隔的两个特征值和一个整数标签）：https://download.csdn.net/download/willen_/11110641