在数据挖掘学习札记:KNN算法(一)里,使用sklearn模块对例子进行了求解,但是并不清楚k的取值。
下面是我写的一个Python代码,程序采用“小题大做”的方式,一方面可以熟悉算法,另一方面练习Python编程,可以看到,当k取1,2,3,4,5时,knn算法预测未知电影的类型都是R,即Romance。
说明:
1. 距离使用欧氏距离;
2. k近邻搜索使用线性扫描;
3. 未知电影对象调用label方法,得到预测类型;
from math import sqrt class Movie: ''' Represents a movie ''' lib=[] total=0 # Number of movies def __init__(self,nk,nf,tg): ''' Initialize a movie ''' self.nkiss=nk self.nfight=nf self.tag=tg self.index=len(self.lib)+1 Movie.lib.append(self)# initialize a movie and add it to the library Movie.total+=1 self.print() def distance(self,movie): ''' Distance to another movie ''' if type(movie)!=type(self): raise TypeError('requires a %s, given a %s'% (type(self),type(movie))) else: dis=(self.nkiss-movie.nkiss)**2 dis=dis+(self.nfight-movie.nfight)**2 dis=sqrt(dis) return dis def get(self,k): ''' Get the kth movie''' if k>self.total: raise IndexError('out of range') else: for m in self.lib: if k==m.index: return m def neighbors(self,k): '''Find its k neighbors ''' dis=[] movie_many=[] for movie in self.lib: dis.append((movie.index,self.distance(movie))) dis.sort(key=lambda dis:dis[1]) # sort according to distances for i in range(1,k+1): movie_many.append(self.get(dis[i][0])) self.print(movie_many) return movie_many def print(self,movies=None): ''' Print the information of a movie or a set of movies ''' if movies==None: print((self.index,self.nkiss,self.nfight,self.tag)) else: for m in movies: m.print() def label(self,k=1): '''From its k nearest neihbors to determin its tag: R or A ?''' if self.tag=='unknown': movie_many=self.neighbors(k) nR=0 nA=0 for movie in movie_many: if movie.tag=='R': nR+=1 elif movie.tag=='A': nA+=1 else: raise TypeError('The movie with label %d is not a training data'%movie.label()) else: if nR>nA: tag='R' elif nR<nR: tag='A' else: tag='unknown' return tag Movie(3,104,'R') Movie(2,100,'R') Movie(1,81,'R') Movie(101,10,'A') Movie(99,5,'A') Movie(98,2,'A') test=Movie(18,90,'unknown')
计算结果如下:
>>> test.label(1) (2, 2, 100, 'R') 'R' >>> test.label(2) (2, 2, 100, 'R') (3, 1, 81, 'R') 'R' >>> test.label(3) (2, 2, 100, 'R') (3, 1, 81, 'R') (1, 3, 104, 'R') 'R' >>> test.label(4) (2, 2, 100, 'R') (3, 1, 81, 'R') (1, 3, 104, 'R') (4, 101, 10, 'A') 'R' >>> test.label(5) (2, 2, 100, 'R') (3, 1, 81, 'R') (1, 3, 104, 'R') (4, 101, 10, 'A') (5, 99, 5, 'A') 'R' >>> test.label(6) (2, 2, 100, 'R') (3, 1, 81, 'R') (1, 3, 104, 'R') (4, 101, 10, 'A') (5, 99, 5, 'A') (6, 98, 2, 'A') 'unknown' >>>
上面的运行如下解释,
首先,初始化6个电影对象,代表6个实例,第7个是未知电影test,它的类型tag未知。然后调用label方法, test.label(2),表示利用test的最近的两个邻居来预测,test电影的类型。在运行中,我们打印出了,该未知电影的k个近邻。
知识点:
1. 使用sort进行排序时,可以指定key,使得按指定方式排序,上面的代码key=lambda dis:dis[1]使用每个元素(二元组)的第二个为键进行排序,因为第二元组的第一个是电影的序号,第二个是距离;
2. 虽然Python没有函数重载,但是可以使用默认参数达到同样的目的,比如print函数,既可以打印一个电影,也可以打印一个电影列表。关于Python的函数重载问题,网上有评论,比如这里。