MNIST Image Classification with KNN

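  • MNIST Image Classification with KNN
    • 1 Core Idea
    • 2 Experimental Environment
    • 3 Implementation
      • 3.1 Data Preprocessing
      • 3.2 KNN Implementation
    • 4 Test Analysis
      • 4.1 Confusion Matrix
      • 4.2 Effect of K
        • (1) Effect of K on accuracy
        • (2) Effect of K on run time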

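1 Core Idea

The core idea of the KNN (k-nearest neighbors) algorithm is this: if most of the K samples nearest to a given sample in feature space belong to one class, the sample is assigned to that class as well and is assumed to share that class's characteristics.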
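
The majority-vote rule can be sketched on a tiny 2-D example. The function name `knn_predict` and the toy points are made up for illustration and are not part of the MNIST code in this post.

```python
import numpy as np

def knn_predict(query, points, labels, k):
    """Return the majority label among the k points nearest to query (L2 distance)."""
    dists = np.linalg.norm(points - query, axis=1)
    nearest = np.argsort(dists, kind="stable")[:k]
    votes = np.bincount(labels[nearest])
    return int(np.argmax(votes))

# Hypothetical toy data: class 0 near the origin, class 1 near (5, 5).
points = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                   [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
labels = np.array([0, 0, 0, 1, 1, 1])

print(knn_predict(np.array([0.5, 0.5]), points, labels, k=3))  # 0
print(knn_predict(np.array([5.5, 5.5]), points, labels, k=3))  # 1
```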

2 Experimental Environment

Windows 10 20H2
Python 3.8.6
Numpy 1.19.5

3 Implementation

import numpy as np
import struct
import time

3.1 Data Preprocessing

# Training set image file
train_images_idx3_ubyte_file = 'MNIST_data/train-images.idx3-ubyte'
# Training set label file
train_labels_idx1_ubyte_file = 'MNIST_data/train-labels.idx1-ubyte'
# Test set image file
test_images_idx3_ubyte_file = 'MNIST_data/t10k-images.idx3-ubyte'
# Test set label file
test_labels_idx1_ubyte_file = 'MNIST_data/t10k-labels.idx1-ubyte'
# Read an image file; each image is flattened into one row
def ReadImgFile(filepath):
    with open(filepath, 'rb') as f:
        _, img_num, img_h, img_w = struct.unpack('>4I', f.read(16))
        img = np.fromfile(f, dtype=np.uint8).reshape(img_num, img_h * img_w)
        return img
# Read a label file; labels are returned as a column vector
def ReadLableFile(filepath):
    with open(filepath, 'rb') as f:
        _, img_num = struct.unpack('>2I', f.read(8))
        label = np.fromfile(f, dtype=np.uint8).reshape(img_num, 1)
        return label
# Load the images and labels of the training and test sets
train_set = ReadImgFile(train_images_idx3_ubyte_file)
train_labels = ReadLableFile(train_labels_idx1_ubyte_file)
test_set = ReadImgFile(test_images_idx3_ubyte_file)
test_labels = ReadLableFile(test_labels_idx1_ubyte_file)
# Convert all data to np.float32
train_set = train_set.astype(np.float32)
train_labels = train_labels.astype(np.float32)
test_set = test_set.astype(np.float32)
test_labels = test_labels.astype(np.float32)

# Keep only the first 1/10 of each dataset
train_set = train_set[:train_set.shape[0]//10]
train_labels = train_labels[:train_labels.shape[0]//10]
test_set = test_set[:test_set.shape[0]//10]
test_labels = test_labels[:test_labels.shape[0]//10]

# Number of samples in each dataset
train_num = train_set.shape[0]
test_num = test_set.shape[0]
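
To see what the `struct.unpack('>4I', ...)` call above is reading, here is a minimal sketch that builds a tiny IDX3-style byte stream in memory and parses it the same way. The helper name `read_idx_images` and the 2-image payload are hypothetical. The magic-number check is an addition for illustration; the code above discards that field.

```python
import io
import struct
import numpy as np

def read_idx_images(stream):
    """Parse an IDX3 image stream: 16-byte big-endian header, then uint8 pixels."""
    magic, num, rows, cols = struct.unpack('>4I', stream.read(16))
    if magic != 2051:  # 0x00000803: unsigned-byte data with 3 dimensions
        raise ValueError("unexpected magic number: {}".format(magic))
    pixels = np.frombuffer(stream.read(num * rows * cols), dtype=np.uint8)
    # One flattened image per row, as in ReadImgFile
    return pixels.reshape(num, rows * cols)

# Hypothetical file content: 2 images of 2x3 pixels with values 0..11.
header = struct.pack('>4I', 2051, 2, 2, 3)
body = bytes(range(12))
images = read_idx_images(io.BytesIO(header + body))

print(images.shape)        # (2, 6)
print(images[1].tolist())  # [6, 7, 8, 9, 10, 11]
```

The label files follow the same pattern with an 8-byte header (magic number and item count), which is why `ReadLableFile` unpacks `'>2I'`.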

3.2 KNN Implementation

# Distance function
def distance(test_image, train_image):
    # L2 (Euclidean) distance, returned as a scalar
    return np.sqrt(np.sum(np.square(test_image - train_image)))
def knn(test_image, train_set, train_labels, k):
    # Compute the distance from test_image to every sample in the training set
    result = [
        {"distance": distance(test_image, train_set[i]),
         "label": train_labels[i]}
        for i in range(train_set.shape[0])
    ]
    # Sort by distance in ascending order
    result = sorted(result, key=lambda item: item["distance"])
    # Keep the k nearest neighbors
    result_k = result[:k]
    # Take the mode of the labels among the k nearest neighbors
    prediction = np.zeros(10, dtype=int)
    for res in result_k:
        prediction[int(res["label"])] += 1
    return np.argmax(prediction)
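
The per-sample Python loop above is easy to follow but slow. As an alternative sketch (not the implementation timed in this post), the distances can be computed for the whole training set at once, and `np.argpartition` can pick the k smallest without a full sort. The name `knn_vectorized` and the toy arrays are assumptions for illustration. When several samples tie at the k-th distance, it may choose different neighbors than the sorted version.

```python
import numpy as np

def knn_vectorized(test_image, train_set, train_labels, k):
    """Predict one label with L2 distances computed in a single NumPy call."""
    # Squared L2 distance to every training row; the square root is skipped
    # because it does not change the ordering of the distances.
    diffs = train_set - test_image
    sq_dists = np.einsum('ij,ij->i', diffs, diffs)
    # Indices of the k smallest distances (unordered among themselves).
    nearest = np.argpartition(sq_dists, k - 1)[:k]
    votes = np.bincount(train_labels[nearest].astype(int).ravel(), minlength=10)
    return int(np.argmax(votes))

# Hypothetical toy data shaped like the arrays above: float32 rows, column-vector labels.
toy_train = np.array([[0, 0], [0, 1], [1, 0], [9, 9], [9, 8], [8, 9]], dtype=np.float32)
toy_labels = np.array([[3], [3], [3], [7], [7], [7]], dtype=np.float32)
toy_query = np.array([8.5, 8.5], dtype=np.float32)

print(knn_vectorized(toy_query, toy_train, toy_labels, k=3))  # 7
```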

4 Test Analysis

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

4.1 Confusion Matrix

# Confusion matrix for k=5 (rows: true digit, columns: predicted digit)
k = 5
pred_set = []
correct = 0
for i in range(test_num):
    pred = knn(test_set[i], train_set, train_labels, k)
    if(pred == test_labels[i]):
        correct += 1
    pred_set.append(pred)
confusion_matrix(test_labels, pred_set)
array([[ 83,   0,   0,   0,   0,   0,   2,   0,   0,   0],
       [  0, 126,   0,   0,   0,   0,   0,   0,   0,   0],
       [  2,   4,  98,   1,   1,   0,   2,   6,   2,   0],
       [  0,   1,   0,  98,   0,   2,   2,   2,   0,   2],
       [  0,   2,   0,   0,  99,   0,   1,   1,   0,   7],
       [  1,   1,   0,   0,   1,  81,   1,   0,   1,   1],
       [  2,   0,   0,   0,   1,   0,  84,   0,   0,   0],
       [  0,   6,   0,   0,   1,   1,   0,  89,   0,   2],
       [  3,   1,   1,   4,   1,   3,   2,   0,  71,   3],
       [  0,   0,   0,   0,   4,   0,   0,   1,   2,  87]], dtype=int64)
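
The matrix above can be reduced to an overall accuracy and a per-digit recall. The sketch below copies the reported values into an array; the variable names are made up for illustration.

```python
import numpy as np

# Confusion matrix reported above for k = 5 (rows: true digit, columns: predicted digit).
cm = np.array([
    [83,   0,  0,  0,  0,  0,  2,  0,  0,  0],
    [ 0, 126,  0,  0,  0,  0,  0,  0,  0,  0],
    [ 2,   4, 98,  1,  1,  0,  2,  6,  2,  0],
    [ 0,   1,  0, 98,  0,  2,  2,  2,  0,  2],
    [ 0,   2,  0,  0, 99,  0,  1,  1,  0,  7],
    [ 1,   1,  0,  0,  1, 81,  1,  0,  1,  1],
    [ 2,   0,  0,  0,  1,  0, 84,  0,  0,  0],
    [ 0,   6,  0,  0,  1,  1,  0, 89,  0,  2],
    [ 3,   1,  1,  4,  1,  3,  2,  0, 71,  3],
    [ 0,   0,  0,  0,  4,  0,  0,  1,  2, 87],
])

# Correct predictions sit on the diagonal.
accuracy = np.trace(cm) / cm.sum()
# Recall per digit: correct predictions divided by the number of true samples of that digit.
recall = np.diag(cm) / cm.sum(axis=1)

print("accuracy: {:.3f}".format(accuracy))                   # accuracy: 0.916
print("lowest recall: digit {}".format(np.argmin(recall)))   # lowest recall: digit 8
```

The diagonal sums to 916 of 1000 test samples, which matches the 91.60% reported for k = 5 in the next section. Digit 8 is the hardest class here, with 71 of 89 samples classified correctly.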

4.2 Effect of K

# Measure accuracy and run time for different values of K
k_set = range(1, 51)
accuracy_set = []
time_set = []
for k in k_set:
    correct = 0
    time_start = time.time()
    for i in range(test_num):
        pred = knn(test_set[i], train_set, train_labels, k)
        if pred == test_labels[i]:
            correct += 1
    time_end = time.time()
    accuracy = correct/test_num*100
    cost_time = time_end-time_start
    accuracy_set.append(accuracy)
    time_set.append(cost_time)
    print("k = {}: accuracy {:.2f}%, run time {:.2f}s.".format(k, accuracy, cost_time))
k = 1: accuracy 90.40%, run time 59.39s.
k = 2: accuracy 90.00%, run time 59.19s.
k = 3: accuracy 91.30%, run time 59.06s.
k = 4: accuracy 91.30%, run time 60.83s.
k = 5: accuracy 91.60%, run time 59.53s.
k = 6: accuracy 91.30%, run time 58.84s.
k = 7: accuracy 91.40%, run time 59.10s.
k = 8: accuracy 91.30%, run time 59.19s.
k = 9: accuracy 90.90%, run time 59.01s.
k = 10: accuracy 90.60%, run time 59.12s.
k = 11: accuracy 90.20%, run time 59.02s.
k = 12: accuracy 90.10%, run time 58.99s.
k = 13: accuracy 89.80%, run time 59.10s.
k = 14: accuracy 89.50%, run time 58.90s.
k = 15: accuracy 89.60%, run time 59.05s.
k = 16: accuracy 89.50%, run time 59.09s.
k = 17: accuracy 89.50%, run time 59.13s.
k = 18: accuracy 89.80%, run time 58.87s.
k = 19: accuracy 89.20%, run time 59.14s.
k = 20: accuracy 88.90%, run time 59.68s.
k = 21: accuracy 88.70%, run time 59.37s.
k = 22: accuracy 88.60%, run time 58.92s.
k = 23: accuracy 88.40%, run time 59.05s.
k = 24: accuracy 88.30%, run time 59.13s.
k = 25: accuracy 88.10%, run time 59.04s.
k = 26: accuracy 87.90%, run time 58.90s.
k = 27: accuracy 87.80%, run time 58.95s.
k = 28: accuracy 87.70%, run time 60.57s.
k = 29: accuracy 87.70%, run time 59.21s.
k = 30: accuracy 87.50%, run time 59.02s.
k = 31: accuracy 87.50%, run time 58.97s.
k = 32: accuracy 87.40%, run time 58.93s.
k = 33: accuracy 87.40%, run time 58.93s.
k = 34: accuracy 87.10%, run time 58.90s.
k = 35: accuracy 86.80%, run time 58.97s.
k = 36: accuracy 87.20%, run time 59.15s.
k = 37: accuracy 87.00%, run time 58.98s.
k = 38: accuracy 86.90%, run time 58.85s.
k = 39: accuracy 86.60%, run time 59.00s.
k = 40: accuracy 86.90%, run time 59.02s.
k = 41: accuracy 86.70%, run time 59.13s.
k = 42: accuracy 86.80%, run time 58.87s.
k = 43: accuracy 86.80%, run time 58.96s.
k = 44: accuracy 86.80%, run time 59.28s.
k = 45: accuracy 87.10%, run time 59.05s.
k = 46: accuracy 87.00%, run time 58.89s.
k = 47: accuracy 87.00%, run time 59.40s.
k = 48: accuracy 86.70%, run time 59.16s.
k = 49: accuracy 86.40%, run time 59.17s.
k = 50: accuracy 86.40%, run time 58.82s.

(1) Effect of K on accuracy

# Plot accuracy against K
plt.plot(k_set, accuracy_set)
plt.xlabel("K")
plt.ylabel("Accuracy")
plt.show()

[Figure 1: accuracy (%) versus K]

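Conclusion: the choice of K has a noticeable effect on accuracy. In this test, accuracy peaks at 91.60% at K = 5 and falls to 86.40% by K = 50.
With a small K, prediction relies on training instances in a small neighborhood. The approximation error of learning decreases, because only training instances close (similar) to the input affect the result, but the estimation error increases. In other words, a smaller K makes the overall model more complex and more prone to overfitting.
With a large K, prediction relies on training instances in a larger neighborhood. This reduces the estimation error of learning but increases the approximation error: training instances far from (dissimilar to) the input also influence the prediction and can cause mistakes. A larger K makes the overall model simpler.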
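
Both failure modes can be reproduced on toy 1-D data. This is a hedged sketch with made-up points and a hypothetical helper `knn_1d`; it is not derived from the MNIST run.

```python
import numpy as np

def knn_1d(query, xs, labels, k):
    """Majority vote among the k training values closest to query."""
    nearest = np.argsort(np.abs(xs - query), kind="stable")[:k]
    return int(np.argmax(np.bincount(labels[nearest], minlength=2)))

# Case A (K too small): one mislabeled point sits right next to the query.
xs_a = np.array([2.0, 2.4, 2.5, 2.6])
labels_a = np.array([1, 0, 0, 0])   # the point at 2.0 is treated as label noise
print(knn_1d(2.1, xs_a, labels_a, k=1))  # 1 (follows the noisy neighbor)
print(knn_1d(2.1, xs_a, labels_a, k=3))  # 0 (the noise is outvoted)

# Case B (K too large): two close class-1 points, five distant class-0 points.
xs_b = np.array([1.0, 1.1, 3.0, 3.2, 3.4, 3.6, 3.8])
labels_b = np.array([1, 1, 0, 0, 0, 0, 0])
print(knn_1d(1.05, xs_b, labels_b, k=3))  # 1 (the local neighborhood decides)
print(knn_1d(1.05, xs_b, labels_b, k=7))  # 0 (distant points dominate the vote)
```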

(2) Effect of K on run time

# Plot run time against K
plt.plot(k_set, time_set)
plt.xlabel("K")
plt.ylabel("cost_time")
plt.show()

[Figure 2: run time (s) versus K]

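Conclusion: both the measurements and the code show that K has little effect on run time. Every run took between 58.82 s and 60.83 s, because each prediction computes the distance to all training samples and sorts all of them regardless of K. K only determines how many of the sorted results are counted in the vote.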
