STM32单片机使用KNN算法实现鸢尾花分类

一、概述

最近在利用业余时间学习机器学习算法,由于笔者是嵌入式软件工程师,想将机器学习算法在单片机端实现,KNN算法(k-Nearest Neighbor,K最近邻算法)是为数不多的可在单片机端实现的机器学习算法。
通过检索发现,在单片机端实现KNN算法的例子较少,仅有几个用单片机实现手写数字识别的。
本例程硬件使用的是STM32F103C8T6最小系统板,IAR/MDK开发环境,STM32CubeMX进行配置并生成工程文件,鸢尾花数据集是从UCI机器学习官网https://archive.ics.uci.edu/ml/index.php下载,笔者已上传到资源,也可移步https://download.csdn.net/download/wanglong3713/26849636直接下载笔者整理好的。程序完整工程已上传至[面包多]:[001]STM32单片机使用KNN算法实现鸢尾花分类(IAR开发)
[002]STM32单片机使用KNN算法实现鸢尾花分类(Keil_MDK版)
欢迎下载。
鸢尾花Iris数据集共150组数据,分3类,分别是Iris Setosa,Iris Versicolour,Iris Virginica,每组数据有4个特征,分别是花萼长度,花萼宽度,花瓣长度,花瓣宽度。本例程选取了3个分类的各前30组共90组作为训练集,剩余20组共60组的数据作为测试集。

二、程序流程

KNN算法原理不再详细介绍,读者可自行检索,本篇博文的目的在于介绍算法的C语言实现。根据算法的原理可得程序流程:
① 计算一个测试样本,与所有训练样本的距离,距离一般选用欧氏距离,也可以选择其他距离,如切比雪夫距离、曼哈顿距离,这几种距离的定义等内容可自行检索;
② 对上述求得的距离从小到大进行排序,可使用最简单的冒泡排序法;
③ 取前k个距离,统计这k个距离中,对应的每种样本分类出现的个数,个数最大的分类,即为测试样本的分类,此处k即为KNN算法中的k;
④ 重复以上步骤,计算其他测试样本的分类。

三、主要代码

1. 计算欧式距离

根据不同的距离公式,程序稍作修改即可得到其他距离,常用距离计算公式可参考常用距离计算单片机C语言程序。

/*******************************************************************************
  * 函数名:EuclideanDistance
  * 功  能:计算一个样本与一个训练样本的欧几里得距离
  * 参  数:*u16DataA测试样本
			*u16DataB训练样本
			u8Size数据维度
  * 返回值:u16Dist距离
  * 说  明:无
*******************************************************************************/
uint16_t EuclideanDistance(uint16_t *u16DataA, uint16_t *u16DataB, uint8_t u8Size)
{
    uint16_t u16Dist = 0;
    int16_t s16Temp = 0;
    uint8_t i;
    for (i = 0; i < u8Size; i++)
    {
        s16Temp = ((int16_t)*(u16DataA + i) - ((int16_t)*(u16DataB + i)));
        s16Temp = s16Temp * s16Temp;
        u16Dist += (uint16_t)s16Temp;
    }
    u16Dist = (uint16_t)sqrt(u16Dist);
    return u16Dist;
}

2. 分类

排好序的数据,统计前k个数据中每个分类出现的个数,个数最大的结果即为分类结果;
通过printf函数打印分类情况到串口。

/*******************************************************************************
  * 函数名:KNN_Classify
  * 功  能:分类
  * 参  数:无
  * 返回值:无
  * 说  明:无
*******************************************************************************/
void KNN_Classify(void)
{
    uint8_t u8SetosaCnt = 0;
    uint8_t u8VersiColorCnt = 0;
    uint8_t u8VirginicaCnt = 0;
    uint8_t u8Max = 0;
    uint8_t i, j, m;
    uint16_t *pTest, *pTrain;
	Result_ts sIrisResult;
    for (i = 0; i < TEST_ROW; i++)
    {
        memset(&sIrisResult, 0, sizeof(sIrisResult));
        for (j = 0; j < TRAIN_ROW; j++)//
        {
            pTest = (uint16_t *)&u16TestSet[i];
            pTrain = (uint16_t *)&u16TrainSet[j];
            sIrisResult.u16Distance[j][0] = EuclideanDistance(pTest, pTrain, TRAIN_COLUMN - 1);//
            //sIrisResult.u16Distance[j][0] = ChebyshevDistance(pTest, pTrain, TRAIN_COLUMN - 1);//
			//sIrisResult.u16Distance[j][0] = ManhattanDistance(pTest, pTrain, TRAIN_COLUMN - 1);
            sIrisResult.u16Distance[j][1] = u16TrainSet[j][4];           
        }
        BubbleSort(sIrisResult.u16Distance, TRAIN_ROW);//第i个测试集的数据,排序
        u8SetosaCnt = 0;
        u8VersiColorCnt = 0;
        u8VirginicaCnt = 0;
		HAL_IWDG_Refresh(&hiwdg);
        for (m = 0; m < K_VALUE; m++)//前k个数据
        {
            switch (sIrisResult.u16Distance[m][1])
            {
                case SETOSA: u8SetosaCnt++; break;                    
                case VERSICOLOR: u8VersiColorCnt++; break;                    
                case VIRGINICA: u8VirginicaCnt++; break;
                default:break;
            }
        }
        u8Max = max(max(u8SetosaCnt, u8VersiColorCnt), u8VirginicaCnt);
        if(u8Max == u8SetosaCnt)
        {
            u8Max = SETOSA;
        }else
        {
            if (u8Max == u8VersiColorCnt)
            {
                u8Max = VERSICOLOR;
            }else
            {
                if (u8Max == u8VirginicaCnt)
                {
                    u8Max = VIRGINICA;
                }
            }
        }
        sIrisResult.u8Class = u8Max;//保存分类结果
        printf(" %.1f,%.1f,%.1f,%.1f ",(float)u16TestSet[i][0]/10,(float)u16TestSet[i][1]/10,(float)u16TestSet[i][2]/10,(float)u16TestSet[i][3]/10);//   
        switch(u8Max)
        {
            case SETOSA:
            {                    
                printf("class: Iris-setosa ");//输出分类结果
                if (sIrisResult.u8Class == u16TestSet[i][4])//分类正确
                {
                    printf(" Success\n");//
                }else
                {
                    printf(" Fail\n");
                }
            }break;
            case VERSICOLOR:
            {
                printf("class: Iris-versicolor ");//输出分类结果
                if (sIrisResult.u8Class == u16TestSet[i][4])//分类正确
                {
                    printf(" Success\n");//
                }else
                {
                    printf(" Fail\n");
                }
            }break;
            case VIRGINICA:
            {
                printf("class: Iris-virginica ");
                if (sIrisResult.u8Class == u16TestSet[i][4])//分类正确
                {
                    printf(" Success\n");//
                }else
                {
                    printf(" Fail\n");
                }
            }break;
            default:break;
        }
		HAL_IWDG_Refresh(&hiwdg);
    }
}

三、运行效果

可看出,60组测试集,有2组分类错误,58组正确,准确率为96.7%。其实用肉眼观察分析这两组数据,也可以看出确实不太好分类。另外,如果训练集和测试集选择的合适,准确率可以达到100%。
STM32单片机使用KNN算法实现鸢尾花分类_第1张图片

四、总结

1. 关于距离公式

在相关文献看到,切比雪夫距离的效果优于其他距离,但实际测试发现并非如此,可能与训练集、测试集的选取有关;

2.关于k的取值

在分类只有2种的情况下,建议k取奇数,防止两种分类出现平局的现象;但在分类有2种以上的时候,无论k是奇数还是偶数,都可能出现两种甚至多种分类出现平局,此时要考虑其他方式,选择出最佳的分类;

3. 关于训练集和测试集

分类效果和训练集、测试集的选取关系很大。本例程只是按照鸢尾花的数据集的顺序,选取了3个分类的各前30组共90组作为训练集,剩余20组共60组的数据作为测试集,如果改变训练集和测试集,最终的分类准确率不同。

你可能感兴趣的:(STM32,算法,单片机,1024程序员节,机器学习,分类,算法)