k-Nearest 最近邻分类算法

概述

kNN算法又称为k最近邻(k-nearest neighbor classification)分类算法。所谓的k最近邻,就是指最接近的k个邻居(数据),即每个样本都可以由它的K个邻居来表达。
kNN算法的核心思想是,在一个含未知样本的空间,可以根据离这个样本最邻近的k个样本的数据类型来确定样本的数据类型。
该算法涉及3个主要因素:训练集、距离与相似的衡量、k的大小;主要考虑因素:距离与相似度。

opencv中使用
        Mat img = oneMat;
        Mat gray;
        cvtColor(img, gray, CV_BGR2GRAY);
        int b = 20;
        int m = gray.rows / b;   //原图为1000*2000
        int n = gray.cols / b;   //裁剪为5000个20*20的小图块
        Mat data,labels;   //特征矩阵
        for (int i = 0; i < n; i++)
        {
            int offsetCol = i*b; //列上的偏移量
            for (int j = 0; j < m; j++)
            {
                int offsetRow = j*b;  //行上的偏移量
                //截取20*20的小块
                Mat tmp;
                gray(Range(offsetRow, offsetRow + b), Range(offsetCol, offsetCol + b)).copyTo(tmp);
                data.push_back(tmp.reshape(0,1));  //序列化后放入特征矩阵
                int l=(int)j / 5;
                labels.push_back(l);  //对应的标注
                LOGI("jason %d", l);
            }

        }
        data.convertTo(data, CV_32F); //uchar型转换为cv_32f
        int samplesNum = data.rows;
        int trainNum = 3000;
        Mat trainData, trainLabels;
        trainData = data(Range(0, trainNum), Range::all());   //前3000个样本为训练数据
        trainLabels = labels(Range(0, trainNum), Range::all());

        //使用KNN算法
        int K = 5;
        Ptr<TrainData> tData = TrainData::create(trainData, ROW_SAMPLE, trainLabels);//降训练数据封装成一个TrainData对象,送入train函数
        Ptr<KNearest> model = KNearest::create();
        model->setDefaultK(K);
        model->setIsClassifier(true);
        model->train(tData);

        //预测分类
        double train_hr = 0, test_hr = 0;
        // compute prediction error on train and test data
        for (int i = 0; i < samplesNum; i++)
        {
            Mat sample = data.row(i);
            float r = model->predict(sample);   //对所有行进行预测
            //预测结果与原结果相比,相等为1,不等为0
            r = std::abs(r - labels.at<int>(i)) <= FLT_EPSILON ? 1.f : 0.f;

            if (i < trainNum)
                train_hr += r;  //累积正确数
            else
                test_hr += r;
        }

        test_hr /= samplesNum - trainNum;
        train_hr = trainNum > 0 ? train_hr / trainNum : 1.;

        LOGI("accuracy: train = %.1f%%, test = %.1f%%\n",
               train_hr*100., test_hr*100.);

你可能感兴趣的:(Android,C语言)