优点:
缺点:
适用的情况:
# Toy 2-D dataset: two visually separated clusters, labelled 1 and 2.
# NOTE(review): assumes `from numpy import array` and
# `import matplotlib.pyplot as plt` earlier in the notes — confirm.
dating_data_mat = array([[1.0, 1.0], [1.1, 1.2], [1.2, 1.1], [1.3, 0.9], [0.1, 0.2], [0.1, 0.1], [0.0, 0.1], [0.1, 0.3]])
dating_labels = [1, 1, 1, 1, 2, 2, 2, 2]
fig = plt.figure() # create a new figure
ax = fig.add_subplot(111) # e.g. 349 would split the canvas into 3 rows x 4 cols and use cell 9; 111 uses the whole canvas
# Column 0 as the x axis, column 1 as the y axis; size/color scale with the label (15 = point size, 1 would be tiny)
ax.scatter(dating_data_mat[:, 0], dating_data_mat[:, 1], 15*array(dating_labels), 15*array(dating_labels))
plt.show() # draw
这种情况下如果想要新来一个点 (0.3, 0.1) 究竟应该属于哪一个区??
KNN的想法就是,这个点距离哪一个区中的点更近就属于哪一个区
更具体的说法:
import numpy as np
def load_file(file_name, training_size=80):
    """Load a tab-separated dataset and split it into training and test parts.

    Each line of the file is expected to hold ``feature1\tfeature2\tlabel``.
    The first ``training_size`` lines become the training set; the remainder
    becomes the test set.

    :param file_name: path of the file to load
    :type file_name: str
    :param training_size: number of leading lines used for training
        (default 80, matching the original hard-coded cutoff)
    :type training_size: int
    :return: training features, training labels, test features, test labels
    :rtype: tuple of four numpy matrices
    """
    data = []
    label = []
    test_data = []
    test_label = []
    # `with` guarantees the file handle is closed even if parsing raises.
    with open(file_name) as fp:
        for count, line in enumerate(fp):
            temp = line.strip().split('\t')
            features = [float(temp[0]), float(temp[1])]
            if count < training_size:
                data.append(features)
                label.append([float(temp[2])])
            else:
                test_data.append(features)
                test_label.append([float(temp[2])])
    return np.mat(data), np.mat(label), np.mat(test_data), np.mat(test_label)
def normalization(distance_of_feature):
    """Min-max normalize every feature column of the matrix, in place.

    Fixes of the original version:
    - normalizes *all* columns, not just column 0 (the caller sums every
      column, so an unnormalized column would dominate the distance);
    - subtracts the column minimum (true min-max scaling into [0, 1]);
    - guards against a zero range (constant column) to avoid division by zero;
    - uses the matrix ``.max()``/``.min()`` methods, which return scalars,
      instead of built-in ``max()``/``min()`` which return 1x1 matrices.

    :param distance_of_feature: per-feature squared differences between every
        training sample and the test sample (one row per training sample)
    :type distance_of_feature: matrix
    :return: the same matrix, normalized column-wise into [0, 1]
    :rtype: matrix
    """
    for col in range(distance_of_feature.shape[1]):
        col_min = distance_of_feature[:, col].min()
        col_range = distance_of_feature[:, col].max() - col_min
        if col_range == 0:
            # Constant column carries no distance information; map it to 0.
            distance_of_feature[:, col] = 0.0
        else:
            distance_of_feature[:, col] = (distance_of_feature[:, col] - col_min) / col_range
    return distance_of_feature
def calculate_distance(data_mat, sample):
    """Compute the normalized squared Euclidean distance from ``sample`` to
    every row of ``data_mat``.

    Generalized from the original: the final sum now runs over *all* feature
    columns via ``sum(axis=1)`` instead of hard-coding columns 0 and 1, so
    the function works for any number of features.

    :param data_mat: training set, one sample per row
    :type data_mat: matrix
    :param sample: test sample (a single row)
    :type sample: matrix
    :return: column vector of distances, one per training sample
    :rtype: matrix
    """
    n_features = data_mat.shape[1]
    dis_list = []
    for row in range(data_mat.shape[0]):
        # Per-feature squared differences for this training sample.
        dis_list.append([(data_mat[row, feature] - sample[0, feature]) ** 2
                         for feature in range(n_features)])
    separate_res = normalization(np.mat(dis_list))
    # Sum the normalized per-feature components into a single distance.
    return separate_res.sum(axis=1)
def find_classification(nor_distance, label_matrix, k):
    """Classify a sample by majority vote among its k nearest neighbours.

    Rewritten to sort with a plain ``argsort`` on the flattened distance
    column instead of the original ``np.lexsort`` + 3-D fancy indexing
    (``temp_sorted[0, i, 1]``), which depended on a fragile quirk of
    indexing a matrix with a 2-D index array.

    :param nor_distance: distance from the test sample to every training
        sample (column vector)
    :type nor_distance: matrix
    :param label_matrix: label of each training sample (+1 or -1)
    :type label_matrix: matrix
    :param k: number of nearest neighbours to vote
    :type k: int
    :return: 1 if the majority of the k nearest labels is 1, else -1
    :rtype: int
    """
    # Indices of training samples ordered by increasing distance.
    order = np.asarray(nor_distance).ravel().argsort()
    positives = sum(1 for idx in order[:k] if label_matrix[idx, 0] == 1)
    # Strict majority of +1 labels wins; ties fall to -1 (as before).
    return 1 if positives > k - positives else -1
# Driver: load the dataset, classify every test sample with k=5 KNN,
# and report the accuracy.
data_mat, label_matrix, test_data_mat, test_label_mat = load_file("testSet.txt")
count = 0
# Use the actual test-set size; the original hard-coded 20, which silently
# mis-reports accuracy when the file does not hold exactly 100 lines.
total = test_data_mat[:, 0].size
for line in range(total):
    nor_distance = calculate_distance(data_mat, test_data_mat[line])
    classification = find_classification(nor_distance, label_matrix, 5)
    # Print every prediction (the original printed only correct ones,
    # which hid all mispredictions from the output).
    print("classification(calculated): %d\nclassification(real): %d\n" % (classification, test_label_mat[line, 0]))
    if classification == test_label_mat[line, 0]:
        count += 1
# Percent sign goes after the number.
print("accuracy: %s%%" % (float(count) / float(total) * 100))
大概是由于选择的数据集各类之间分得很开,且数量较少,加之KNN本身在这类数据上分类准确率就很高,最终测得的准确率为100%
这里的训练数据80,测试数据20左右
结果:
classification(calculated): 1
classification(real): 1
classification(calculated): 1
classification(real): 1
classification(calculated): 1
classification(real): 1
classification(calculated): 1
classification(real): 1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
classification(calculated): 1
classification(real): 1
classification(calculated): -1
classification(real): -1
classification(calculated): 1
classification(real): 1
classification(calculated): 1
classification(real): 1
classification(calculated): 1
classification(real): 1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
classification(calculated): -1
classification(real): -1
accuracy:%100.0