KNN例子

参考CS231n,将KNN 跑起来了,成功将系统搞死,,内存和计算能力开销太大。

以下代码 切记不用轻易跑。。


数据集:

http://www.cs.toronto.edu/~kriz/cifar.html


code:


import os

import sys

import numpy as np

import pickle

def load_CIFAR_batch(filename):

"""

cifar-10数据集是分batch存储的,这是载入单个batch

@参数 filename: cifar文件名

@r返回值: X, Y: cifar batch中的 data 和 labels

"""

with open(filename,"rb") as f :

datadict = pickle.load(f,encoding='iso-8859-1')

print(filename)

X=datadict['data']

Y=datadict['labels']

X=X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")

Y=np.array(Y)

return X, Y

def load_CIFAR10(ROOT):

"""

读取载入整个 CIFAR-10 数据集

@参数 ROOT: 根目录名

@return: X_train, Y_train: 训练集 data 和 labels

X_test, Y_test: 测试集 data 和 labels

"""

xs=[]

ys=[]

for b in range(1,6):

f=os.path.join(ROOT, "data_batch_%d" % (b, ))

X, Y=load_CIFAR_batch(f)

xs.append(X)

ys.append(Y)

X_train=np.concatenate(xs)

Y_train=np.concatenate(ys)

del X, Y

X_test, Y_test=load_CIFAR_batch(os.path.join(ROOT, "test_batch"))

return X_train, Y_train, X_test, Y_test

# 载入训练和测试数据集

X_train, Y_train, X_test, Y_test = load_CIFAR10('data/cifar/')

# 把32*32*3的多维数组展平

Xtr_rows = X_train.reshape(X_train.shape[0], 32 * 32 * 3) # Xtr_rows : 50000 x 3072

Xte_rows = X_test.reshape(X_test.shape[0], 32 * 32 * 3) # Xte_rows : 10000 x 3072

class NearestNeighbor:

def __init__(self):

pass

def train(self, X, y):

"""

这个地方的训练其实就是把所有的已有图片读取进来 -_-||

"""

# the nearest neighbor classifier simply remembers all the training data

self.Xtr = X

self.ytr = y

def predict(self, X):

"""

所谓的预测过程其实就是扫描所有训练集中的图片,计算距离,取最小的距离对应图片的类目

"""

num_test = X.shape[0]

# 要保证维度一致哦

Ypred = np.zeros(num_test, dtype = self.ytr.dtype)

# 把训练集扫一遍 -_-||

for i in range(num_test):

# 计算l1距离,并找到最近的图片

distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)

min_index = np.argmin(distances) # 取最近图片的下标

Ypred[i] = self.ytr[min_index] # 记录下label

return Ypred

nn = NearestNeighbor() # 初始化一个最近邻对象

nn.train(Xtr_rows, Y_train) # 训练...其实就是读取训练集

Yte_predict = nn.predict(Xte_rows) # 预测

# 比对标准答案,计算准确率

print ('accuracy: %f' % ( np.mean(Yte_predict == Y_test)))

你可能感兴趣的:(KNN例子)