【Pytroch】基于K邻近算法的数据分类预测(Excel可直接替换数据)

【Pytroch】基于K邻近算法的数据分类预测(Excel可直接替换数据)

  • 1.模型原理
  • 2.数学公式
  • 3.文件结构
  • 4.Excel数据
  • 5.下载地址
  • 6.完整代码
  • 7.运行结果

1.模型原理

K最近邻(K-Nearest Neighbors,简称KNN)是一种简单但常用的机器学习算法,用于分类和回归问题。它的核心思想是基于已有的训练数据,通过测量样本之间的距离来进行分类预测。在实现KNN算法时,可以使用PyTorch来进行计算和操作。

下面是使用PyTorch实现KNN算法的一般步骤:

  1. 准备数据集:首先,需要准备训练数据集,包括样本特征和对应的标签。

  2. 计算距离:对于每个待预测的样本,计算它与训练数据集中每个样本的距离。常见的距离度量包括欧氏距离、曼哈顿距离等。

  3. 排序与选择:将计算得到的距离按照从小到大的顺序进行排序,并选择距离最近的K个样本。

  4. 投票或平均:对于分类问题,选择K个样本中出现最多的类别作为预测结果;对于回归问题,选择K个样本的标签的平均值作为预测结果。

2.数学公式

当使用K最近邻(KNN)算法进行数据分类预测时,以下是其基本原理的数学描述:

  1. 距离度量:假设我们有一个训练数据集 D D D,其中包含 n n n 个样本。每个样本 x i x_i xi 都有 m m m 个特征,可以表示为 x i = ( x i 1 , x i 2 , … , x i m ) x_i = (x_{i1}, x_{i2}, \ldots, x_{im}) xi=(xi1,xi2,,xim)。对于一个待预测的样本 x new x_{\text{new}} xnew,我们需要计算它与训练集中每个样本的距离。常见的距离度量方式包括欧氏距离(Euclidean Distance)和曼哈顿距离(Manhattan Distance)等:

    • 欧氏距离: d ( x i , x new ) = ∑ j = 1 m ( x i j − x new , j ) 2 d(x_i, x_{\text{new}}) = \sqrt{\sum_{j=1}^m (x_{ij} - x_{\text{new},j})^2} d(xi,xnew)=j=1m(xijxnew,j)2

    • 曼哈顿距离: d ( x i , x new ) = ∑ j = 1 m ∣ x i j − x new , j ∣ d(x_i, x_{\text{new}}) = \sum_{j=1}^m |x_{ij} - x_{\text{new},j}| d(xi,xnew)=j=1mxijxnew,j

  2. 排序与选择:计算完待预测样本与所有训练样本的距离后,我们将距离按照从小到大的顺序排序。然后选择距离最近的 K K K 个训练样本。

  3. 投票或平均:对于分类问题,我们可以统计这 K K K 个样本中每个类别出现的次数,然后选择出现次数最多的类别作为预测结果。这就是所谓的“投票法”:

    • y ^ = argmax c ∑ i = 1 K I ( y i = c ) \hat{y} = \text{argmax}_{c} \sum_{i=1}^{K} I(y_i = c) y^=argmaxci=1KI(yi=c)

    其中, y ^ \hat{y} y^ 是预测的类别, y i y_i yi 是第 i i i 个样本的真实类别, c c c 是类别。

    对于回归问题,我们可以选择 K K K 个样本的标签的平均值作为预测结果。

总结起来,K最近邻算法的基本原理是通过测量样本之间的距离来进行分类预测。对于分类问题,通过投票法确定预测类别;对于回归问题,通过取标签的平均值来预测数值。在实际应用中,需要选择合适的距离度量和适当的 K K K 值,以及进行必要的数据预处理和特征工程。

3.文件结构

在这里插入图片描述

iris.xlsx						% 可替换数据集
Main.py							% 主函数

4.Excel数据

在这里插入图片描述

5.下载地址

- Excle资源下载地址

6.完整代码

import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt


def knn(X_train, y_train, X_test, k=5):
    X_train = torch.tensor(X_train, dtype=torch.float32)
    X_test = torch.tensor(X_test, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.long)

    predictions = []

    for i in range(X_test.shape[0]):
        distances = torch.sum((X_train - X_test[i]) ** 2, dim=1)
        _, indices = torch.topk(distances, k, largest=False)  # 获取距离最小的k个邻居的索引
        knn_labels = y_train[indices]
        pred = torch.mode(knn_labels).values  # 投票选出标签
        predictions.append(pred.item())

    return predictions

def plot_confusion_matrix(conf_matrix, classes):
    plt.figure(figsize=(8, 6))
    plt.imshow(conf_matrix, cmap=plt.cm.Blues, interpolation='nearest')
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()

def plot_predictions_vs_true(y_true, y_pred):
    plt.figure(figsize=(10, 6))
    plt.plot(y_true, 'go', label='True Labels')
    plt.plot(y_pred, 'rx', label='Predicted Labels')
    plt.title("True Labels vs Predicted Labels")
    plt.xlabel("Sample Index")
    plt.ylabel("Class Label")
    plt.legend()
    plt.show()

def main():
    # 读取Data.xlsx文件并加载数据
    data = pd.read_excel("iris.xlsx")

    # 划分特征值和标签
    features = data.iloc[:, :-1].values
    labels = data.iloc[:, -1].values

    # 将数据集拆分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

    y_pred = knn(X_train, y_train, X_test, k=5)
    accuracy = accuracy_score(y_test, y_pred)
    print("训练集准确率:{:.2%}".format(accuracy))


    conf_matrix = confusion_matrix(y_test, y_pred)
    print("混淆矩阵:")
    print(conf_matrix)

    classes = np.unique(y_test)
    plot_confusion_matrix(conf_matrix, classes)
    plot_predictions_vs_true(y_test, y_pred)

if __name__ == "__main__":
    main()

7.运行结果

【Pytroch】基于K邻近算法的数据分类预测(Excel可直接替换数据)_第1张图片

【Pytroch】基于K邻近算法的数据分类预测(Excel可直接替换数据)_第2张图片
【Pytroch】基于K邻近算法的数据分类预测(Excel可直接替换数据)_第3张图片

你可能感兴趣的:(#,pytroch分类模型,算法,分类)