CIFAR-10 最近邻分类识别 python3 NearestNeighbor

用到的数据集:

数据集下载地址:http://www.cs.toronto.edu/~kriz/cifar.html ,请自行下载其中的 Python 版本(CIFAR-10 python version)。

参考的是斯坦福大学CS231N教程的notes

实现程序如下:

需要指出的是,数据集比较大。为了快速验证程序是否正确,这里只选取了前1000张图片进行训练、前100张图片进行测试。如果需要进行完整的训练和测试,把代码中的切片索引([:1000] 和 [:100])去掉即可。


import pickle as p
import matplotlib.pyplot as plt
import numpy as np


# NearestNeighbor class
class NearestNeighbor(object):
    """1-nearest-neighbor classifier using the L1 (Manhattan) distance.

    ``train`` only memorizes the data; ``predict`` compares every query row
    against every stored row and copies the label of the closest one.
    """

    def __init__(self):
        pass

    @staticmethod
    def _signed(arr):
        """Return ``arr`` in a dtype whose subtraction cannot wrap around.

        Unsigned integer arrays (e.g. uint8 image pixels) are widened to the
        smallest signed type that can hold every pairwise difference; any
        other dtype is returned unchanged.
        """
        arr = np.asarray(arr)
        if arr.dtype.kind == 'u':
            # uint8 -> int16, uint16 -> int32, uint32 -> int64, uint64 -> float64
            return arr.astype(np.promote_types(arr.dtype, np.int8))
        return arr

    def train(self, X, y):
        """Memorize the training set.

        X is N x D where each row is an example. y is 1-dimensional of
        size N and must be a NumPy array (``predict`` reads ``y.dtype``).
        """
        # the nearest neighbor classifier simply remembers all the training data
        self.Xtr = X
        self.ytr = y

    def predict(self, X):
        """Predict a label for every row of X.

        X is M x D where each row is an example we wish to predict a label
        for. Returns a 1-dimensional array of M labels with the same dtype
        as the training labels.
        """
        num_test = X.shape[0]
        # make sure that the output type matches the training label type
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # Subtracting unsigned arrays wraps around modulo 2**bits (for uint8,
        # 0 - 10 == 246), which silently corrupts the distances. Widen both
        # operands once, outside the loop, so the differences are exact.
        train_rows = self._signed(self.Xtr)
        test_rows = self._signed(X)

        # loop over all test rows
        for i in range(num_test):
            # L1 distance (sum of absolute differences) from the i'th test
            # row to every training row
            distances = np.sum(np.abs(train_rows - test_rows[i, :]), axis=1)
            # copy the label of the closest training example
            Ypred[i] = self.ytr[np.argmin(distances)]

        return Ypred


def load_CIFAR_batch(filename):
    """Load a single pickled CIFAR batch file.

    Returns a tuple ``(X, Y)``: ``X`` is the batch's 'data' entry exactly as
    stored in the file, and ``Y`` is its 'labels' entry converted to a NumPy
    array.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted files.
    with open(filename, 'rb') as batch_file:
        batch = p.load(batch_file, encoding='latin1')
    images = batch['data']
    # the labels are stored as a plain list; turn them into an array
    labels = np.array(batch['labels'])
    return images, labels


def load_CIFAR_Labels(filename):
    """Return the 'label_names' entry of a pickled CIFAR metadata file."""
    # NOTE: pickle can execute arbitrary code; only load trusted files.
    with open(filename, 'rb') as meta_file:
        meta = p.load(meta_file, encoding='latin1')
    return meta['label_names']
# Load the class names, the five training batches and the test batch.
label_names = load_CIFAR_Labels("cifar-10-batches-py/batches.meta")
imgX1, imgY1 = load_CIFAR_batch("cifar-10-batches-py/data_batch_1")
imgX2, imgY2 = load_CIFAR_batch("cifar-10-batches-py/data_batch_2")
imgX3, imgY3 = load_CIFAR_batch("cifar-10-batches-py/data_batch_3")
imgX4, imgY4 = load_CIFAR_batch("cifar-10-batches-py/data_batch_4")
imgX5, imgY5 = load_CIFAR_batch("cifar-10-batches-py/data_batch_5")
Xte_rows, Yte = load_CIFAR_batch("cifar-10-batches-py/test_batch")

# Stack the five training batches into one data array and one label array.
Xtr_rows = np.concatenate([imgX1, imgX2, imgX3, imgX4, imgX5], axis=0)
Ytr_rows = np.concatenate([imgY1, imgY2, imgY3, imgY4, imgY5], axis=0)

# Only a small subset is used so that the demo finishes quickly.
num_train = 1000
num_test = 100

nn = NearestNeighbor()
# "Training" just stores the first num_train images and their labels.
nn.train(Xtr_rows[:num_train], Ytr_rows[:num_train])
# Predict labels for the first num_test test images.
Yte_predict = nn.predict(Xte_rows[:num_test])

# Accuracy is the fraction of predictions that match the true labels.
accuracy = np.mean(Yte_predict == Yte[:num_test])
print('accuracy: %f' % accuracy)

# Display the first 1024 values of training row 6 as a 32x32 grayscale image.
image = imgX1[6, :1024].reshape((32, 32))
print(image.shape)
plt.imshow(image, cmap=plt.cm.gray)
plt.axis('off')  # hide the axes around the picture
plt.show()

你可能感兴趣的:(信息论,机器学习)