参考的同学的博客:
https://blog.csdn.net/Willen_/article/details/89288218
心得感悟:
同学实现的时候画出来的图有些不对劲:样本点在 LDA 直线上的垂足位置看起来不对。其实投影的计算本身没有问题,只是横、纵坐标轴的显示比例不一致,垂线在屏幕上看起来就不垂直了(开个玩笑:他应该买个正方形的显示器)。
以下是实现代码:
import numpy as np
import matplotlib.pyplot as plt
def load_data(file_name):
    """Load a tab-separated training set with two features and one label.

    Each non-blank line must contain at least three tab-separated fields:
    two floating-point feature values followed by an integer class label.
    Blank lines are skipped.

    Args:
        file_name (str): Path of the training data file.

    Returns:
        tuple: ``(feature_array, label_array)`` where ``feature_array`` is a
        float array with one ``[feature_1, feature_2]`` row per sample and
        ``label_array`` is an int array with one ``[label]`` row per sample.

    Raises:
        ValueError: If a field cannot be parsed as a number.
        IndexError: If a non-blank line has fewer than three fields.
    """
    feature_data = []
    label_data = []
    # The context manager closes the file even when parsing raises.
    with open(file_name) as fr:
        for line in fr:
            # Tolerate empty lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            fields = line.split('\t')
            feature_data.append([float(fields[0]), float(fields[1])])
            # Labels are kept as single-element rows, giving an (n, 1) array.
            label_data.append([int(fields[2])])
    feature_array = np.array(feature_data, dtype=float)
    label_array = np.array(label_data, dtype=int)
    return feature_array, label_array
def LDA(x1, x2):
    """Compute the two-class Fisher LDA projection direction.

    The direction is ``Sw^-1 (u1 - u2)``, where ``u1`` and ``u2`` are the
    class means and ``Sw`` is the within-class scatter matrix (the sum of the
    scatter matrices of both classes).

    Args:
        x1 (array-like): Samples of class 1, one row per sample.
        x2 (array-like): Samples of class 2, one row per sample.

    Returns:
        numpy.matrix: Row matrix of shape ``(1, n_features)`` holding the
        direction ``w``; index it as ``w[0, j]``.

    Raises:
        ValueError: If either class is not a non-empty 2-D array.
        numpy.linalg.LinAlgError: If the within-class scatter matrix is
            singular.
    """
    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    if x1.ndim != 2 or x2.ndim != 2 or x1.shape[0] == 0 or x2.shape[0] == 0:
        raise ValueError("x1 and x2 must be non-empty 2-D arrays of samples")

    u1 = np.mean(x1, axis=0)
    u2 = np.mean(x2, axis=0)
    d1 = x1 - u1
    d2 = x2 - u2
    Sw = np.dot(d1.T, d1) + np.dot(d2.T, d2)

    # Solve Sw * w = (u1 - u2) instead of forming the explicit inverse.
    w = np.linalg.solve(Sw, u1 - u2)
    # np.mat was removed in NumPy 2.0; np.asmatrix keeps the (1, d) matrix
    # return type that callers index as w[0, j].
    return np.asmatrix(w)
测试代码:
if __name__ == "__main__":
    # 1. import data
    print("-----1. load data-----")
    feature_data, label_data = load_data("train_data.txt")

    # Flatten the label column so it can be used as a boolean row mask,
    # then split the samples by class (label 0 -> x1, label 1 -> x2).
    labels = np.asarray(label_data).ravel()
    x1 = feature_data[labels == 0]
    x2 = feature_data[labels == 1]

    w = LDA(x1, x2)
    print(w)

    # 2. plot the figure
    print("-----2. plot the figure-----")

    # Direction of the LDA line, which passes through the origin.
    direction = np.asarray(w, dtype=float).ravel()
    norm_sq = float(np.dot(direction, direction))
    if norm_sq == 0.0:
        raise ValueError("LDA direction is the zero vector; cannot plot the line")

    # Points of the LDA line to draw. A vertical line (direction[0] == 0)
    # has no finite slope, so it is drawn along the y axis instead.
    ticks = np.arange(-5, 5)
    if direction[0] != 0:
        line_x = ticks
        line_y = ticks * (direction[1] / direction[0])
    else:
        line_x = np.zeros(len(ticks))
        line_y = ticks

    # Orthogonal projection of every sample onto the LDA line:
    # p = (x . w / w . w) * w. This equals the intersection of the line with
    # the perpendicular through the sample, without dividing by a slope that
    # may be zero or infinite.
    x1_proj = np.outer(np.dot(x1, direction), direction) / norm_sq
    x2_proj = np.outer(np.dot(x2, direction), direction) / norm_sq

    plt.plot(line_x, line_y)

    print("x1")
    print(type(x1))
    print(x1)
    print("feature_data")
    print(type(feature_data))
    print(feature_data)

    # plot points of class 1
    plt.scatter(x1[:, 0], x1[:, 1], s=10, c='b')
    # plot points of class 2
    plt.scatter(x2[:, 0], x2[:, 1], s=10, c='r')
    # plot the projections of both classes onto the LDA line
    plt.scatter(x1_proj[:, 0], x1_proj[:, 1], s=10, c='b')
    plt.scatter(x2_proj[:, 0], x2_proj[:, 1], s=10, c='r')

    # Use the same scale on both axes; otherwise the perpendicular feet look
    # misplaced even though they are computed correctly.
    plt.axis('equal')

    plt.title('LDA', fontsize=24)
    plt.xlabel('feature 1', fontsize=14)
    plt.ylabel('feature 2', fontsize=14)
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.show()
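画图不对的原因:matplotlib 默认不保证横、纵坐标轴的单位长度相同,所以即使垂足的坐标计算正确,样本点到 LDA 直线的连线在图上看起来也不垂直。在 plt.show() 之前调用 plt.axis('equal')(或 plt.gca().set_aspect('equal'))让两个坐标轴使用相同的比例即可。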
数据集:https://download.csdn.net/download/thisismykungfu/11136541