# Demo
# -*- coding: utf-8 -*-
'''
Project: KNNDemo
File: demo.py
Author: Weifu Liu
Time:2020.06.12
'''
import numpy as np
import operator
import matplotlib.pyplot as plt
def createDataSet():
    '''
    Input: None
    Output: group - the training dataset, a 4x2 array of sample points
            classLabelVector - one integer class code per sample
                               (1 for 'A', 2 for 'B'), usable as scatter colors
    Function: Prepare the toy data for the KNN demo
    '''
    group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    names = ['A', 'A', 'B', 'B']
    # Translate each textual label into its numeric code.
    codeOf = {'A': 1, 'B': 2}
    classLabelVector = [codeOf[name] for name in names if name in codeOf]
    return group, classLabelVector
def createScatter(dataset, classLabelVector):
    '''
    Input: dataset - 2-D array of samples; column 0 is drawn on the x axis
                     and column 1 on the y axis
           classLabelVector - one numeric class code per sample, used as the
                              point color
    Output: None (the scatter plot is displayed)
    Function: Plot a scatter plot from the sample dataset
    '''
    print(dataset)
    plt.scatter(dataset[:, 0], dataset[:, 1], alpha=0.6, c=classLabelVector)
    # Column 0 is on the horizontal axis, so it is the one labelled X.
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title("KNN Demo")
    plt.show()
def classify0(inX, dataSet, labels, k):
    '''
    Input: inX - test vector (one value per feature)
           dataSet - training samples, one row per sample
           labels - class label of each training row, in row order
           k - parameter K of the KNN algorithm; values larger than the
               number of training rows use every row
    Output: the label that wins the majority vote among the k nearest rows
    Function: Classify inX with the k-nearest-neighbours rule using the
              Euclidean distance
    Raises: ValueError if nothing can vote (k < 1 or dataSet has no rows)
    '''
    dataSet = np.asarray(dataSet)
    # Broadcasting subtracts inX from every training row at once.
    diffMat = np.asarray(inX) - dataSet
    distances = np.sqrt((diffMat ** 2).sum(axis=1))
    # A stable sort keeps equidistant samples in dataset order, so ties are
    # resolved the same way on every run.
    sortedDistIndicies = distances.argsort(kind='stable')
    # Never ask for more neighbours than there are training samples.
    neighbours = min(k, dataSet.shape[0])
    if neighbours < 1:
        raise ValueError('classify0 needs k >= 1 and a non-empty dataSet')
    classCount = {}
    for index in sortedDistIndicies[:neighbours]:
        voteILabel = labels[index]
        classCount[voteILabel] = classCount.get(voteILabel, 0) + 1
    # max() returns the first label reaching the top count; labels enter the
    # dict nearest-first, so a tie goes to the label with the closest sample.
    return max(classCount.items(), key=operator.itemgetter(1))[0]
def main():
    '''
    Input: None
    Output: None
    Function: Build the toy dataset, show it as a scatter plot and print the
              class predicted for one test point
    '''
    samples, codes = createDataSet()
    print(samples[0])
    createScatter(samples, codes)
    query = np.array([1.2, 1.1])
    predicted = classify0(query, samples, codes, 3)
    # The message below is Chinese for "The class of the test data is:".
    print("测试数据的类别为:", predicted)


if __name__ == "__main__":
    main()
import numpy as np
import operator
import matplotlib.pyplot as plt
import csv
import random
def PreProcess(path):
    '''
    Input: path - the file path of the raw dataset (CSV rows made of four
                  feature columns followed by the species name)
    Output: None
    Function: Randomly split the dataset into data/tran_data.csv (roughly
              four rows in five) and data/test_data.csv (roughly one row in
              five), replacing the species name with a numeric code
              ('1', '2' or '3'). Rows that are blank, too short, or carry an
              unknown species name are skipped so that every written row
              has exactly five columns.
    '''
    speciesCode = {
        'Iris-setosa': '1',
        'Iris-versicolor': '2',
        'Iris-virginica': '3',
    }
    # The with-statement closes (and therefore flushes) all three files even
    # when a row raises part-way through the split.
    with open(path, 'r', encoding='utf8', newline='') as f, \
            open('data/tran_data.csv', 'w', encoding='utf8', newline='') as tran_csv, \
            open('data/test_data.csv', 'w', encoding='utf8', newline='') as test_csv:
        tran_writer = csv.writer(tran_csv)
        test_writer = csv.writer(test_csv)
        for row in csv.reader(f):
            # csv.reader yields [] for blank lines; header rows and rows with
            # an unrecognised species cannot be encoded either.
            if len(row) < 5 or row[4] not in speciesCode:
                continue
            record = row[:4] + [speciesCode[row[4]]]
            # randint(1, 5) == 1 sends about one row in five to the test set.
            if random.randint(1, 5) == 1:
                test_writer.writerow(record)
            else:
                tran_writer.writerow(record)
def LoadDataset():
    '''
    Input: None
    Output: tran_data, test_data - 2-D float arrays with one row per sample
    Function: Load tran_data and test_data from data/tran_data.csv and
              data/test_data.csv
    '''
    # ndmin=2 keeps a one-row file as a (1, n) matrix. Without it loadtxt
    # returns a flat vector, and iterating over it yields single numbers
    # instead of sample rows.
    tran_data = np.loadtxt('data/tran_data.csv', delimiter=',', ndmin=2)
    test_data = np.loadtxt('data/test_data.csv', delimiter=',', ndmin=2)
    return tran_data, test_data
def calDistance(inX, dataset, k):
    '''
    Input: inX - test vector laid out like a dataset row (index 4 is the
                 label slot and is ignored for the distance)
           dataset - tran_data
           k - KNN K value
    Output: Predicted class, as returned by KNNcls
    Function: Calculate the Euclidean distance from inX to every row of
              dataset over all columns except the label column, then
              classify with KNNcls
    '''
    labelColumn = 4
    # Broadcasting subtracts inX from every training row at once.
    diffMat = np.asarray(inX) - dataset
    # Remove the label column before squaring. Adding its square to the sum
    # and subtracting it afterwards would cost floating-point precision.
    featureDiff = np.delete(diffMat, labelColumn, axis=1)
    distances = np.sqrt((featureDiff ** 2).sum(axis=1))
    return KNNcls(distances, k, dataset)
def KNNcls(distances, k, dataset):
    '''
    Input: distances - the distance from the test vector to each row of
                       dataset, in row order
           k - KNN K value; values larger than the number of distances use
               every row
           dataset - tran_data (the class code is read from column 4)
    Output: Predicted class name
    Function: Perform KNN classification to get the prediction result
    Raises: ValueError if nothing can vote (k < 1 or distances is empty)
            KeyError if a class code is not 1, 2 or 3
    '''
    Labels = {"1": "Iris-setosa", "2": "Iris-versicolor", "3": "Iris-virginica"}
    # Never ask for more neighbours than there are training samples.
    neighbours = min(k, len(distances))
    if neighbours < 1:
        raise ValueError('KNNcls needs k >= 1 and at least one distance')
    # A stable sort keeps equidistant samples in dataset order, so ties are
    # resolved the same way on every run.
    sortedDistanceIndex = np.asarray(distances).argsort(kind='stable')
    voteCount = {}
    for index in sortedDistanceIndex[:neighbours]:
        labelID = dataset[index][4]
        labelName = Labels[str(int(labelID))]
        voteCount[labelName] = voteCount.get(labelName, 0) + 1
    # max() returns the first name reaching the top count; names enter the
    # dict nearest-first, so a tie goes to the class with the closest sample.
    return max(voteCount.items(), key=operator.itemgetter(1))[0]
def main():
    '''
    Input: None
    Output: None
    Function: Use the classifier to test the data set for accuracy, then
              classify one hand-written sample
    '''
    # To regenerate the train/test split, call PreProcess('data/iris.csv')
    # before loading.
    tran_data, test_data = LoadDataset()
    k = 3
    Labels = {"1": "Iris-setosa", "2": "Iris-versicolor", "3": "Iris-virginica"}
    sumCount = 0
    correctCount = 0
    for inX in test_data:
        print(inX)
        actualLabel = Labels[str(int(inX[4]))]
        print('actual = ', actualLabel)
        predictedLabel = calDistance(inX, tran_data, k)
        print('predicted = ', predictedLabel)
        sumCount += 1
        if actualLabel == predictedLabel:
            correctCount += 1
    # The random split can leave the test set empty; avoid dividing by zero.
    if sumCount:
        print('Accuracy:', correctCount / sumCount)
    else:
        print('Accuracy: n/a (the test set is empty)')
    # The fifth value only fills the label slot of the row layout.
    inY = [4.9, 3.0, 1.4, 0.2, 2.0]
    preLabel = calDistance(inY, tran_data, k)
    print('predicted = ', preLabel)


if __name__ == "__main__":
    main()