CIFAR-10 是一个常用的图像分类数据集，共包含 60000 张 32×32 像素的彩色小图片，每张图片都标注了一个类别（共 10 类），其中 50000 张作为训练集，10000 张作为测试集。
在 Python 中读取 CIFAR-10 数据的代码如下：
A Python 2 routine that opens such a file and returns a dictionary:
def unpickle(file):
    """Open a pickled CIFAR-10 batch file and return the unpickled object (Python 2).

    Uses the C-accelerated ``cPickle`` module, which exists only on Python 2.
    """
    import cPickle
    with open(file, 'rb') as handle:
        # Returning inside ``with`` still closes the file before the caller resumes.
        return cPickle.load(handle)
And a Python 3 version:
def unpickle(file):
    """Open a pickled CIFAR-10 batch file and return the unpickled object (Python 3).

    ``encoding='bytes'`` makes any Python 2 8-bit strings stored in the pickle
    load as ``bytes``. For the CIFAR-10 batches this presumably means the dict
    keys are ``bytes`` (e.g. ``b'data'``); verify against the files in use.
    """
    import pickle
    with open(file, 'rb') as handle:
        # Returning inside ``with`` still closes the file before the caller resumes.
        return pickle.load(handle, encoding='bytes')
# -*- coding: utf-8 -*-
import os
import numpy as np
import pickle
def load_CIFAR_batch(filename):
    """Load one pickled CIFAR-10 batch file and return its images and labels.

    Args:
        filename: path to a batch file such as ``data_batch_1`` or
            ``test_batch``.

    Returns:
        ``(X, Y)``. ``X`` is a float64 array of shape ``(N, 32, 32, 3)``
        (channels-last images) and ``Y`` is a 1-D numpy array of the labels.
        ``N`` is inferred from the file (10000 for the standard batches)
        rather than hard-coded, so smaller or custom batches also load.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # this on trusted dataset files.
    with open(filename, 'rb') as fo:
        # encoding='bytes' lets Python 3 read Python 2 pickles; keys come back as bytes.
        d = pickle.load(fo, encoding='bytes')
    X = d[b'data']
    Y = d[b'labels']
    # Each row holds 3 consecutive 32x32 channel planes: reshape to
    # (N, 3, 32, 32), then move channels last. -1 lets numpy infer N.
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(Y)
    return X, Y
def load_CIFAR10(ROOT):
    """Load the full CIFAR-10 dataset from directory ``ROOT``.

    Reads the training files ``data_batch_1`` ... ``data_batch_5`` (in that
    order) and then ``test_batch``, each via ``load_CIFAR_batch``.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``. The training arrays are the
        five training batches concatenated along the first axis; the test
        arrays come from ``test_batch`` unchanged.
    """
    batches = [
        load_CIFAR_batch(os.path.join(ROOT, "data_batch_%d" % (idx, )))
        for idx in range(1, 6)
    ]
    # Merge the per-batch arrays into one array each for images and labels.
    X_train = np.concatenate([images for images, _ in batches])
    Y_train = np.concatenate([labels for _, labels in batches])
    # Drop the per-batch copies before loading the test batch.
    del batches
    X_test, Y_test = load_CIFAR_batch(os.path.join(ROOT, "test_batch"))
    return X_train, Y_train, X_test, Y_test
# Load the dataset. The raw string keeps the Windows backslash literal:
# "\p" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
X_train, Y_train, X_test, Y_test = load_CIFAR10(r'F:\python/cifar-10-batches-py/')
# Flatten each 32x32x3 image into one row vector; -1 lets numpy compute 3072.
Xtr_rows = X_train.reshape(X_train.shape[0], -1)  # Xtr_rows : 50000 x 3072
Xte_rows = X_test.reshape(X_test.shape[0], -1)  # Xte_rows : 10000 x 3072
class NearestNeighbor:
    """k-nearest-neighbour classifier using the L1 (Manhattan) distance.

    ``train`` only memorises the training data; all computation happens in
    ``predict``.
    """

    def __init__(self):
        pass

    def train(self, X, y):
        """Memorise the training data.

        Args:
            X: 2-D array with one training sample per row.
            y: 1-D numpy array of labels; ``y[i]`` is the label of ``X[i]``.
        """
        # The nearest neighbour classifier simply remembers all the training data.
        self.Xtr = X
        self.ytr = y

    def predict(self, X, k=1):
        """Predict a label for every row of ``X`` by majority vote of its k nearest neighbours.

        Args:
            X: 2-D array with one test sample per row and the same number of
                columns as the training data.
            k: number of nearest training samples that vote (default 1). If
                ``k`` exceeds the training-set size, every training sample
                votes exactly once.

        Returns:
            1-D array of predicted labels with the training labels' dtype.

        Raises:
            ValueError: if ``k < 1``.
        """
        if k < 1:
            raise ValueError("k must be >= 1, got %r" % (k,))
        num_test = X.shape[0]
        # Use the training labels' dtype so predictions compare cleanly with them.
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        for i in range(num_test):
            # L1 distance from test sample i to every training sample.
            # NOTE(review): unsigned-integer inputs would wrap on this
            # subtraction; float inputs (as loaded in this file) are safe.
            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
            # Stable sort: equal distances keep ascending index order (the same
            # tie-break as np.argmin), and each training sample appears once.
            nearest = np.argsort(distances, kind='stable')[:k]
            votes = {}
            for idx in nearest:
                label = self.ytr[idx]
                votes[label] = votes.get(label, 0) + 1
            # max() returns the first maximal key in insertion order, so a vote
            # tie goes to the label whose member is closest to the test sample.
            Ypred[i] = max(votes, key=votes.get)
        return Ypred
# Hold out the first 1000 training rows as a validation set and keep the
# remaining 49000 for training (the right-hand side is evaluated first).
Xval_rows, Xtr_rows = Xtr_rows[:1000, :], Xtr_rows[1000:, :]
Yval, Ytr = Y_train[:1000], Y_train[1000:]
nn = NearestNeighbor()  # create the k-NN classifier
nn.train(Xtr_rows, Ytr)  # "training" just stores the training data
# Report validation accuracy for each candidate k.
for k in (3, 5, 7, 10, 20):
    Yte_predict = nn.predict(Xval_rows, k)
    print('k=%d' % k, 'accuracy: %f' % np.mean(Yte_predict == Yval))
print("end")