CIFAR10中的KNN分类

1. KNN思路

       相对于NN算法而言,KNN的关键在于,通过欧式距离计算待分类图像与训练集图像之间的距离,做距离之间从小到大的排序,找出前K个从小到大出现次数最多的标签作为预测分类标签。具体的示意图如图1所示。

                CIFAR10中的KNN分类_第1张图片

                                                 图1 KNN分类CIFAR10大致思路

2. python代码实现

(1) 数据的提取

import pickle
import operator
import numpy as np
import pandas as pd


# 数据获取
def unpickle(file):
    with open(file,'rb') as f:
        dict = pickle.load(f, encoding='bytes')
    return dict

(2) KNN算法类

# K近邻算法
class KNearestNeighbor:
    def __init__(self):
        pass
    
    def train(self,X,y):
        '''X is size of N x D matrix, Y is 1-dimesion of size N'''
        self.Xtr = X
        self.ytr = y
        
    def predict(self,X,k):
        '''X is N x D matrix where each row is an example we wish to predict label for.
           k is the nearest neighbor algorithm'''
        num = X.shape[0]
        Ypred = np.zeros(num)
        for i in range(num):
            # 利用欧式距离
            distance = np.sum((self.Xtr - X[i,:])**2,axis=1)**0.5 
            # 对距离结果排序,得到从小到大索引
            sortedDistanceIndexs = distance.argsort()
            # k近邻的k循环,统计前k个距离最小的样本
            countDict = {}
            for j in range(k):
                countY = self.ytr[sortedDistanceIndexs[j]] # 得到前k个从小到大索引的样本类别
                countDict[countY] = countDict.get(countY,0) + 1 # 统计出现不存在则为0
            
            # 对前k个距离最小做value排序,找出统计次数最多的类别,作为预测类别
            sortedCountDict = sorted(countDict.items(),key=operator.itemgetter(1),reverse=True)
            Ypred[i] = sortedCountDict[0][0]
        return Ypred

(3) 准确率的计算与运行时间

%%time
# KNN对图像集做分类,计算准确率
top_num = 50
train_data = unpickle('../computer_vision_lifeifei/cifar-10-batches-py/data_batch_5')
test_data = unpickle('../computer_vision_lifeifei/cifar-10-batches-py/test_batch')

knn = KNearestNeighbor()
knn.train(train_data[b'data'],np.array(train_data[b'labels']))
Ypred = knn.predict(test_data[b'data'][:top_num,:],3)

accur = np.sum(np.array(Ypred)==np.array(test_data[b'labels'][:top_num])) / len(Ypred)
print(accur)

输出结果如下:

 

 

你可能感兴趣的:(计算机视觉,图像分类)