目录
创建数据集
自写版KNN算法
优化版KNN算法
效果可视化
创建一个电影分类数据集
接吻次数 | 打斗次数 | 电影类型 |
3 | 100 | 动作片 |
1 | 90 | 动作片 |
2 | 81 | 动作片 |
101 | 10 | 爱情片 |
99 | 5 | 爱情片 |
98 | 2 | 爱情片 |
代码实现:
def Dataset():
data = np.array([[3, 100], [1, 90], [2, 81], [101, 10], [99, 5], [98, 2]])
labels = [ '动作片', '动作片', '动作片', '爱情片', '爱情片', '爱情片']
return data, labels
代码实现:
def Knn(in_data, train_data, train_labels, k):
# 计算欧式距离
distance = np.zeros(train_data.shape[0])
for i in range(train_data.shape[0]):
distance[i] = (in_data[0] - train_data[i][0]) ** 2 +
(in_data[1] - train_data[i][1]) ** 2
#开方处理
distance[i] = np.power(distance[i], 0.5)
# 返回按距离排序的索引
index = np.zeros(train_data.shape[0])
index = distance.argsort()
# 统计前k个最小距离对应的标签个数
love = 0
action = 0
for i in range(k):
if(train_labels[index[i]] == '爱情片'):
love += 1
else:
action += 1
if (love > action):
print('该电影类型为爱情片')
else:
print('该电影类型为动作片')
KNN算法原理:
将输入数据(x, y)与数据集中的数据(xi,yi)分别计算欧氏距离,将欧氏距离按照递增排列,统计前k个距离中对应的标签的个数,输入数据的标签即为k个距离中标签个数最多的那个。
欧氏距离计算:
计算输入数据与数据集中每个数据的欧式距离,并按照递增排序,由于我们只需得到最小的k个距离对应的索引即可(方便后续按照索引寻找到对应的标签),因此利用argsort()函数,该函数将数列排序,并返回原来索引。
统计前k个距离对应标签类型的个数:我自己写的时候没有想到简便的方法,于是创建了两个变量love和actor分别统计标签爱情片和动作片出现的个数,个数最多的标签即为输入数据对应的类型。
测试:当输入数据为in_data = [10, 50]时,结果如下:
实现代码:
def KNN(in_data, train_data, train_labels, k):
train_data_size = train_data.shape[0]
# 将输入数据平铺为train_data_size行1列,便于与训练数据做差
distance = (np.tile(in_data, (train_data_size, 1)) - train_data) ** 2
add_distance = distance.sum(axis=1)
sq_distance = add_distance ** 0.5 # 欧氏距离
# 将欧氏距离排序,返回对应的索引值
index = sq_distance.argsort()
classdict = {}
# 寻找前k个最小距离对应的标签
for i in range(k):
vote_label = train_labels[index[i]] # 第i个距离对应的标签
classdict[vote_label] = classdict.get(vote_label, 0) + 1 #统计某个标签个数
sort_classdict = sorted(classdict.items(), key=operator.itemgetter(1), reverse=True)
return sort_classdict[0][0]
几个优化点:
代码实现:
def data_show(in_data, train_data):
# 显示训练数据
x = []
y = []
for i in range(train_data.shape[0]):
x.append(train_data[i][0])
y.append(train_data[i][1])
plt.plot(x, y, "*")
plt.xlabel("Number of kisses")
plt.ylabel("Number of fights")
plt.plot(in_data[0], in_data[1], "r*")
plt.show()
横轴代表接吻次数,纵轴代表打斗次数。由结果可发现当输入数据为in_data=[10, 50]时,该电影为动作片,即图中红色点。