KNN pseudocode (simple and detailed versions)

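Simple version:
    # fetch_k_neighbors and calc_predict_label are helpers left undefined in this pseudocode
    class SimpleKNN:
        def fit(self, train, k):
            # KNN is a lazy learner: fitting only stores the training data and k
            self.train = train
            self.k = k

        def predict(self, test):
            # a. Get the k samples in the training data closest to the query sample test
            neighbors = fetch_k_neighbors(self.train, test, self.k)
            # b. Combine these k nearest samples into a predicted value
            predict_label = calc_predict_label(neighbors)
            return predict_label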
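
The two helpers in the simple version are left abstract. Below is a minimal sketch of how they might look, under some assumptions that are not in the original: samples are stored in a hypothetical `Sample` tuple with `features` and `label` fields, distance is Euclidean (the pseudocode does not fix a metric), and the prediction is a plain majority vote for classification.

```python
import math
from typing import NamedTuple, Sequence


class Sample(NamedTuple):
    # Hypothetical container; the pseudocode only implies a sample carries a label
    features: Sequence[float]
    label: str


def fetch_k_neighbors(train, x, k):
    """Return the k training samples closest to the feature vector x."""
    # Assumption: Euclidean distance (math.dist needs Python 3.8+)
    by_distance = sorted(train, key=lambda s: math.dist(s.features, x))
    return by_distance[:k]


def calc_predict_label(neighbors):
    """Return the most frequent label among the neighbors (majority vote)."""
    counts = {}
    for neighbor in neighbors:
        counts[neighbor.label] = counts.get(neighbor.label, 0) + 1
    return max(counts, key=counts.get)


train = [
    Sample((1.0, 1.0), "A"),
    Sample((1.2, 0.8), "A"),
    Sample((0.9, 1.1), "A"),
    Sample((5.0, 5.0), "B"),
    Sample((5.2, 4.9), "B"),
]
neighbors = fetch_k_neighbors(train, (1.1, 1.0), k=3)
print(calc_predict_label(neighbors))  # prints A
```

Sorting the whole training set costs O(n log n) per query, which is fine for a toy example; larger datasets usually call for a partial sort or a spatial index instead.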
  
Detailed version:
    class DetailedKNN:
        def fit(self, train, k):
            self.train = train
            self.k = k

        def predict(self, test):
            result = []
            for x in test:
                # a. Get the k samples in the training data closest to the current sample x
                neighbors = fetch_k_neighbors(self.train, x, self.k)

                # b. Combine these k nearest samples into a predicted value
                # b1. Count how many times each class label occurs
                label_2_count_dict = {}
                for neighbor in neighbors:
                    # b11. Get the label of the current neighbor
                    label = neighbor.label
                    # b12. Add this label to the dictionary
                    if label not in label_2_count_dict:
                        label_2_count_dict[label] = 1
                    else:
                        label_2_count_dict[label] += 1
                # b2. Take the most frequent label in the dictionary as the prediction
                max_label_count = 0
                max_label = None
                for label in label_2_count_dict:
                    # Look up the count for the current label
                    count = label_2_count_dict[label]
                    # Compare the count with the running maximum and keep the larger one
                    if count > max_label_count:
                        max_label_count = count
                        max_label = label
                # b3. Append the prediction to the result list
                result.append(max_label)
            return result
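
One detail of step b2 is easy to miss: because the comparison is a strict `>`, a tie between labels goes to whichever label was added to the dictionary first. The sketch below isolates the counting logic in a hypothetical helper, `majority_label`, to show this. It assumes the neighbor labels arrive nearest first, which the pseudocode does not guarantee.

```python
from collections import Counter


def majority_label(labels):
    # Same logic as steps b1 and b2: count the labels, then keep the strict maximum
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    best_label, best_count = None, 0
    for label, count in counts.items():
        if count > best_count:  # strict ">" keeps the earlier label on a tie
            best_label, best_count = label, count
    return best_label


# k = 4 with a 2-2 tie; "B" is seen first, so it wins
tied = ["B", "A", "A", "B"]
print(majority_label(tied))                 # prints B
# Counter gives the same answer in one line (ties also go to the first label seen)
print(Counter(tied).most_common(1)[0][0])   # prints B
print(majority_label(["A", "B", "A"]))      # prints A
```

For two-class problems, choosing an odd k avoids such ties altogether.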
