应用支持向量机(SVM)实现图像分类——Python

应用支持向量机(SVM)实现图像分类——Python

文章目录

    • 1.代码运行
    • 2.注意事项
    • 3.代码分析
    • 4.源代码

1.代码运行

  1. 输入 1 测试一张图片并预测结果

  2. 输入 2 对测试集整体进行测试,得出准确率(2分钟左右)

  3. 输入其他数字自动退出程序

2.注意事项

  1. 本程序包含python库较多,请自行配置(pip),如有需求,请评论或私信
    应用支持向量机(SVM)实现图像分类——Python_第1张图片

  2. 回复其他数字会自动退出程序

  3. 输入图片要求是28*28像素

  4. 模型训练大概需要5分钟,请耐心等候!

  5. 本代码使用本地MNIST数据库,请将代码放至合适位置(同目录下有raw文件夹,文件夹内有以下四个文件:train-images.idx3-ubyte、train-labels.idx1-ubyte、t10k-images.idx3-ubyte、t10k-labels.idx1-ubyte)
    应用支持向量机(SVM)实现图像分类——Python_第2张图片

3.代码分析

  1. 加载MNIST数据库
def load_mnist_train(path, kind='train'):
    """Load the training images and labels from a pair of IDX files.

    Args:
        path: Directory containing ``<kind>-labels.idx1-ubyte`` and
            ``<kind>-images.idx3-ubyte``.
        kind: File-name prefix of the split to load.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a uint8 array with one
        flattened image per row, shape ``(num, rows * cols)`` as declared in
        the image file header; ``labels`` is a 1-D uint8 array.

    Raises:
        ValueError: If a file has an unexpected magic number, or the number
            of labels differs from the number of images.
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # '>II' reads two big-endian unsigned 32-bit ints:
        # the magic number and the declared item count.
        magic, _ = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError('%s is not an IDX label file (magic=%d)'
                             % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError('%s is not an IDX image file (magic=%d)'
                             % (images_path, magic))
        # Use the dimensions from the header instead of a hard-coded 784 so
        # that images of any size are flattened correctly.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(num, rows * cols)
    if len(labels) != num:
        raise ValueError('label count (%d) does not match image count (%d)'
                         % (len(labels), num))
    return images, labels

def load_mnist_test(path, kind='t10k'):
    """Load the test images and labels from a pair of IDX files.

    Args:
        path: Directory containing ``<kind>-labels.idx1-ubyte`` and
            ``<kind>-images.idx3-ubyte``.
        kind: File-name prefix of the split to load.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a uint8 array with one
        flattened image per row, shape ``(num, rows * cols)`` as declared in
        the image file header; ``labels`` is a 1-D uint8 array.

    Raises:
        ValueError: If a file has an unexpected magic number, or the number
            of labels differs from the number of images.
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # Header: big-endian magic number followed by the item count.
        magic, _ = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError('%s is not an IDX label file (magic=%d)'
                             % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError('%s is not an IDX image file (magic=%d)'
                             % (images_path, magic))
        # Flatten using the header dimensions rather than a hard-coded 784.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(num, rows * cols)
    if len(labels) != num:
        raise ValueError('label count (%d) does not match image count (%d)'
                         % (len(labels), num))
    return images, labels

  2. 训练模型


def creat_model():
    """Train an SVM classifier on the training set and save it as ``HW.model``.

    Reads the module-level ``train_images`` and ``train_labels``. When
    ``HW.model`` already exists in the working directory, nothing is trained;
    only a timestamp and a notice are printed.
    """
    stamp_format = '%Y-%m-%d %H:%M:%S'

    if os.path.exists('HW.model'):
        print(time.strftime(stamp_format))
        print('已经存在model文件')
        return

    # Standardize the features. NOTE: the fitted scaler is discarded here and
    # is not stored alongside the model.
    scaled = preprocessing.StandardScaler().fit_transform(train_images)
    features = scaled[:60000]
    targets = train_labels[:60000]

    # Fit the classifier, printing a timestamp before and after.
    print(time.strftime(stamp_format))
    print('开始训练模型,大概需要5分钟,请耐心等候!!!')
    classifier = svm.SVC()
    classifier.fit(features, targets)
    print(time.strftime(stamp_format))

    joblib.dump(classifier, 'HW.model')

  3. 测试单张和多张


def Test_one(imgPath,modelPath):
    """Classify a single image file and print the predicted digit.

    The image is converted to grayscale, flattened, and standardized with a
    scaler fitted on the module-level ``train_images`` -- the same data and
    transform ``creat_model`` uses -- so the sample is scaled the way the
    model's training data was. (Standardizing the lone 28x28 matrix would
    instead normalize each image column on its own.)

    Args:
        imgPath: Path of the image to classify. An empty value, such as the
            one returned by a cancelled file dialog, is reported and skipped.
        modelPath: Despite its name, the already-loaded classifier object
            (it must provide ``predict``), not a file path.
    """
    if not imgPath:
        print('未选择图片\n')
        return
    model_svc = modelPath
    pixels = np.array(Image.open(imgPath).convert('L'), dtype=np.uint8)
    if pixels.size != train_images.shape[1]:
        # The classifier only accepts as many features as a training image has.
        print('图片尺寸不符,要求是28*28像素\n')
        return
    scaler = preprocessing.StandardScaler().fit(train_images)
    # reshape(1, -1) yields the single-row 2-D input that predict() expects,
    # so the sample no longer has to be duplicated.
    sample = scaler.transform(pixels.reshape(1, -1))
    print('预测结果是:', model_svc.predict(sample)[0], '\n')


def Test_all():
    """Print the saved model's accuracy on the test set.

    Loads ``HW.model`` from the working directory and reads the module-level
    ``train_images``, ``test_images`` and ``test_labels``.
    """
    model_svc = joblib.load('HW.model')
    # Fit the scaler on the training data (as creat_model does) and only
    # transform the test data with it; fitting on the test set would scale
    # the features with different statistics than the model was trained on.
    scaler = preprocessing.StandardScaler().fit(train_images)
    x_test = scaler.transform(test_images)[0:10000]
    y_true = test_labels[0:10000]
    # score() already runs the prediction; the previous extra predict() call
    # repeated that work and its result was never used.
    print(model_svc.score(x_test, y_true))

  4. 主程序

if __name__ == '__main__':
    # Script entry point: load both data sets, make sure a trained model
    # exists, then serve a small interactive menu until the user quits.
    print('本程序会默认加载测试集、训练集与模型,只需将四个文件置于raw文件夹即可,模型文件会生成在本文件同目录下')
    train_images, train_labels = load_mnist_train(path)
    test_images, test_labels = load_mnist_test(path)
    creat_model()
    print('数字格式为28*28像素,路径没有要求!!!')
    print('回复数字1为测试一张图片,回复数字2为测试测试集准确率,回复其他数字自动退出程序!!!\n')
    modelPath = None  # holds the loaded model object; loaded lazily, at most once
    while True:
        # Compare the typed text directly instead of eval()-ing it: eval would
        # execute any expression the user types and raise on empty input.
        ans = input('测试一张(1)还是测试集准确率(2):').strip()
        if ans == '1':
            if modelPath is None:
                # creat_model() has written (or found) 'HW.model', so it is
                # loaded without asking the user for a location.
                modelPath = joblib.load('HW.model')
            # A hidden root window is required for the file dialog; destroy it
            # afterwards so windows do not pile up across iterations.
            root = tk.Tk()
            root.withdraw()
            print('请选择测试图片')
            try:
                imgPath = filedialog.askopenfilename()
            finally:
                root.destroy()
            Test_one(imgPath,modelPath)
        elif ans == '2':
            Test_all()
        else:
            # Any other answer ends the program.
            raise SystemExit

4.源代码

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :HW5.py
# @Time      :2023/4/8 22:56
# @Author    :YKW
import struct
import os
import time
import joblib
import numpy as np
import tkinter as tk
from tkinter import filedialog
from PIL import Image
from sklearn import svm
from sklearn import preprocessing

path = './raw/'  # MNIST源文件目录


def load_mnist_train(path, kind='train'):
    """Load the training images and labels from a pair of IDX files.

    Args:
        path: Directory containing ``<kind>-labels.idx1-ubyte`` and
            ``<kind>-images.idx3-ubyte``.
        kind: File-name prefix of the split to load.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a uint8 array with one
        flattened image per row, shape ``(num, rows * cols)`` as declared in
        the image file header; ``labels`` is a 1-D uint8 array.

    Raises:
        ValueError: If a file has an unexpected magic number, or the number
            of labels differs from the number of images.
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # '>II' reads two big-endian unsigned 32-bit ints:
        # the magic number and the declared item count.
        magic, _ = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError('%s is not an IDX label file (magic=%d)'
                             % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError('%s is not an IDX image file (magic=%d)'
                             % (images_path, magic))
        # Use the dimensions from the header instead of a hard-coded 784 so
        # that images of any size are flattened correctly.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(num, rows * cols)
    if len(labels) != num:
        raise ValueError('label count (%d) does not match image count (%d)'
                         % (len(labels), num))
    return images, labels


def load_mnist_test(path, kind='t10k'):
    """Load the test images and labels from a pair of IDX files.

    Args:
        path: Directory containing ``<kind>-labels.idx1-ubyte`` and
            ``<kind>-images.idx3-ubyte``.
        kind: File-name prefix of the split to load.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a uint8 array with one
        flattened image per row, shape ``(num, rows * cols)`` as declared in
        the image file header; ``labels`` is a 1-D uint8 array.

    Raises:
        ValueError: If a file has an unexpected magic number, or the number
            of labels differs from the number of images.
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # Header: big-endian magic number followed by the item count.
        magic, _ = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError('%s is not an IDX label file (magic=%d)'
                             % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError('%s is not an IDX image file (magic=%d)'
                             % (images_path, magic))
        # Flatten using the header dimensions rather than a hard-coded 784.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(num, rows * cols)
    if len(labels) != num:
        raise ValueError('label count (%d) does not match image count (%d)'
                         % (len(labels), num))
    return images, labels


def creat_model():
    """Train an SVM classifier on the training set and save it as ``HW.model``.

    Reads the module-level ``train_images`` and ``train_labels``. When
    ``HW.model`` already exists in the working directory, nothing is trained;
    only a timestamp and a notice are printed.
    """
    stamp_format = '%Y-%m-%d %H:%M:%S'

    if os.path.exists('HW.model'):
        print(time.strftime(stamp_format))
        print('已经存在model文件')
        return

    # Standardize the features. NOTE: the fitted scaler is discarded here and
    # is not stored alongside the model.
    scaled = preprocessing.StandardScaler().fit_transform(train_images)
    features = scaled[:60000]
    targets = train_labels[:60000]

    # Fit the classifier, printing a timestamp before and after.
    print(time.strftime(stamp_format))
    print('开始训练模型,大概需要5分钟,请耐心等候!!!')
    classifier = svm.SVC()
    classifier.fit(features, targets)
    print(time.strftime(stamp_format))

    joblib.dump(classifier, 'HW.model')


def Test_one(imgPath,modelPath):
    """Classify a single image file and print the predicted digit.

    The image is converted to grayscale, flattened, and standardized with a
    scaler fitted on the module-level ``train_images`` -- the same data and
    transform ``creat_model`` uses -- so the sample is scaled the way the
    model's training data was. (Standardizing the lone 28x28 matrix would
    instead normalize each image column on its own.)

    Args:
        imgPath: Path of the image to classify. An empty value, such as the
            one returned by a cancelled file dialog, is reported and skipped.
        modelPath: Despite its name, the already-loaded classifier object
            (it must provide ``predict``), not a file path.
    """
    if not imgPath:
        print('未选择图片\n')
        return
    model_svc = modelPath
    pixels = np.array(Image.open(imgPath).convert('L'), dtype=np.uint8)
    if pixels.size != train_images.shape[1]:
        # The classifier only accepts as many features as a training image has.
        print('图片尺寸不符,要求是28*28像素\n')
        return
    scaler = preprocessing.StandardScaler().fit(train_images)
    # reshape(1, -1) yields the single-row 2-D input that predict() expects,
    # so the sample no longer has to be duplicated.
    sample = scaler.transform(pixels.reshape(1, -1))
    print('预测结果是:', model_svc.predict(sample)[0], '\n')


def Test_all():
    """Print the saved model's accuracy on the test set.

    Loads ``HW.model`` from the working directory and reads the module-level
    ``train_images``, ``test_images`` and ``test_labels``.
    """
    model_svc = joblib.load('HW.model')
    # Fit the scaler on the training data (as creat_model does) and only
    # transform the test data with it; fitting on the test set would scale
    # the features with different statistics than the model was trained on.
    scaler = preprocessing.StandardScaler().fit(train_images)
    x_test = scaler.transform(test_images)[0:10000]
    y_true = test_labels[0:10000]
    # score() already runs the prediction; the previous extra predict() call
    # repeated that work and its result was never used.
    print(model_svc.score(x_test, y_true))


if __name__ == '__main__':
    # Script entry point: load both data sets, make sure a trained model
    # exists, then serve a small interactive menu until the user quits.
    print('本程序会默认加载测试集、训练集与模型,只需将四个文件置于raw文件夹即可,模型文件会生成在本文件同目录下')
    train_images, train_labels = load_mnist_train(path)
    test_images, test_labels = load_mnist_test(path)
    creat_model()
    print('数字格式为28*28像素,路径没有要求!!!')
    print('回复数字1为测试一张图片,回复数字2为测试测试集准确率,回复其他数字自动退出程序!!!\n')
    modelPath = None  # holds the loaded model object; loaded lazily, at most once
    while True:
        # Compare the typed text directly instead of eval()-ing it: eval would
        # execute any expression the user types and raise on empty input.
        ans = input('测试一张(1)还是测试集准确率(2):').strip()
        if ans == '1':
            if modelPath is None:
                # creat_model() has written (or found) 'HW.model', so it is
                # loaded without asking the user for a location.
                modelPath = joblib.load('HW.model')
            # A hidden root window is required for the file dialog; destroy it
            # afterwards so windows do not pile up across iterations.
            root = tk.Tk()
            root.withdraw()
            print('请选择测试图片')
            try:
                imgPath = filedialog.askopenfilename()
            finally:
                root.destroy()
            Test_one(imgPath,modelPath)
        elif ans == '2':
            Test_all()
        else:
            # Any other answer ends the program.
            raise SystemExit

你可能感兴趣的:(python基础,python,支持向量机,分类,机器学习,sklearn)