人工智能(AI)之KNN的基本实现

数据集下载地址:点我下载(原文中的下载链接已丢失)
本文主要介绍KNN的实现思想:

  1. KNN的主要思想就是:计算每个测试样本与所有训练样本之间的距离或相似度(欧氏距离、余弦相似度、曼哈顿距离等),然后取出与之最相似的前K个训练样本,用它们的标签对该测试样本进行预测
  2. 通过测试发现,就本次的数据集而言,把余弦相似度与欧氏距离加权组合来确定预测值效果较好,但这一结论仅针对本次的训练数据
  3. KNN中还有很多细节可以优化,比如对数据集进行归一化;而归一化的方法也有很多,具体采用哪一种要看当前的数据集,找到适合的才是最好的
  4. 总之对于预测,找好模型才是最重要的,框架确定之后,再来讨论具体的优化会更有效果
#include <cmath>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <vector>

using namespace std;

// Mood label indices. They select the first dimension of proba / newproba
// and are used as case labels when choosing output files.
const int ANGER    = 0;
const int DISGUST  = 1;
const int FEAR     = 2;
const int JOY      = 3;
const int SAD      = 4;
const int SURPRISE = 5;

// Fixed-size scratch line buffer (300 chars).
char c[300];

// Min-heap of doubles; not referenced by the functions in this listing.
priority_queue<double, vector<double>, greater<double> > q;
// Score -> index map, ascending key order.
map<double, int> map1;
// Score -> index map, descending key order.
// (Pre-C++11 compilers require the space in "> >".)
map<double, int, greater<double> > map2;

// Names of the two dataset splits.
const string Str1 = "train";
const string Str2 = "test";

// Vocabulary: sorted set of distinct words.
set<string> sets;

// Binary document-term matrix: vector_old[doc][word_column].
bool vector_old[2000][4000];
// Row-normalised copy of vector_old.
double vector2[2000][4000];
// Gold per-mood values of the training documents: proba[mood][doc].
double proba[9][2000];
// Predicted per-mood values of the test documents: newproba[mood][doc].
double newproba[9][2000];
// Squared Euclidean distances from the current test document to each training document.
double dis_save[2000];
// Number of neighbours, read from stdin.
double K;
int num1 = 0;

void readanger()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/anger_train.txt");
    int i = 0;
    while (in && i < 246){
        memset(c, 0, sizeof(c));
        in.getline(c, 300);
        string s;
        s.append(c, 300);
        stringstream ss(s);
        ss >> s; // 第一个单词不用
        double d;
        ss >> d;
        proba[ANGER][i++] = d;
    }
    in.close();
}

void readdisgust()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/disgust_train.txt");
    int i = 0;
    while (in && i < 246){
        memset(c, 0, sizeof(c));
        in.getline(c, 300);
        string s;
        s.append(c, 300);
        stringstream ss(s);
        ss >> s; // 第一个单词不用
        double d;
        ss >> d;
        proba[DISGUST][i++] = d;
    }
    in.close();
}

void readfear()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/fear_train.txt");
    int i = 0;
    while (in && i < 246){
        memset(c, 0, sizeof(c));
        in.getline(c, 300);
        string s;
        s.append(c, 300);
        stringstream ss(s);
        ss >> s; // 第一个单词不用
        double d;
        ss >> d;
        proba[FEAR][i++] = d;
    }
    in.close();
}

void readjoy()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/joy_train.txt");
    int i = 0;
    while (in && i < 246){
        memset(c, 0, sizeof(c));
        in.getline(c, 300);
        string s;
        s.append(c, 300);
        stringstream ss(s);
        ss >> s; // 第一个单词不用
        double d;
        ss >> d;
        proba[JOY][i++] = d;
    }
    in.close();
}

void readsad()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/sad_train.txt");
    int i = 0;
    while (in && i < 246){
        memset(c, 0, sizeof(c));
        in.getline(c, 300);
        string s;
        s.append(c, 300);
        stringstream ss(s);
        ss >> s; // 第一个单词不用
        double d;
        ss >> d;
        proba[SAD][i++] = d;
    }
    in.close();
}

void readsurprise()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/surprise_train.txt");
    int i = 0;
    while (in && i < 246){
        memset(c, 0, sizeof(c));
        in.getline(c, 300);
        string s;
        s.append(c, 300);
        stringstream ss(s);
        ss >> s; // 第一个单词不用
        double d;
        ss >> d;
        proba[SURPRISE][i++] = d;
    }
    in.close();
}

void get_proba()
{
    readanger();
    readdisgust();
    readfear();
    readsad();
    readjoy();
    readsurprise();
}

// Builds the vocabulary from Dataset_words.txt and writes it to anger.txt.
//
// The first line of the input is a header and is skipped. On every other
// line, the first whitespace-separated token is skipped and each remaining
// token is inserted into the global set `sets`. The sorted vocabulary is then
// written to the output file, one word per line.
//
// Tokens are read with `while (tokens >> word)`, so the last token of a line
// is never processed twice. No placeholder entry is stored for the skipped
// first token, so `sets` contains only real words.
void get_word()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Dataset_words.txt");
    std::ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/anger.txt");
    if (!in || !out) {
        std::cerr << "open in or out file error" << std::endl;
        return;
    }

    std::string line;
    bool is_header = true;
    while (std::getline(in, line)) {
        if (is_header) {
            is_header = false;
            continue;
        }
        std::istringstream tokens(line);
        std::string word;
        tokens >> word;  // the first token of a line is not a vocabulary word
        while (tokens >> word) {
            sets.insert(word);
        }
    }

    for (std::set<std::string>::const_iterator it = sets.begin(); it != sets.end(); ++it) {
        out << *it << '\n';
    }
}


// Removes stop words from the vocabulary.
//
// Every whitespace-separated token of "Foxstoplist (1).txt" is echoed to
// Foxstoplistout.txt, one per line. The token is also erased from the global
// set `sets` if present.
//
// The file is opened read-only, since it is never written. Each token is
// processed exactly once, and removal is a keyed erase rather than a linear
// scan of the set.
void clear_stopwords()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Foxstoplist (1).txt");
    std::ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Foxstoplistout.txt");
    if (!in) {
        std::cerr << "cannot open stop-word list" << std::endl;
        return;
    }

    std::string line;
    while (std::getline(in, line)) {
        std::istringstream tokens(line);
        std::string word;
        while (tokens >> word) {
            out << word << '\n';
            sets.erase(word);  // no-op if the word is not in the vocabulary
        }
    }
}


// Fills the binary document-term matrix vector_old from Dataset_words.txt and
// writes the vocabulary header to vector.txt.
//
// The first input line is a header and is skipped. Each following line
// becomes one row; row 0 is the first data line. The row's first token is
// skipped. Every remaining token found in `sets` sets
// vector_old[row][column] = true, where column is the word's position in the
// sorted vocabulary. Tokens not in `sets` (e.g. removed stop words) are
// ignored.
//
// vector.txt receives only the header "文本编号 " followed by each vocabulary
// word and a space. The matrix itself is not written.
void vector_out()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Dataset_words.txt");
    std::ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/vector.txt");

    if (in && out) {
        // Column index of every vocabulary word, built once. This replaces a
        // full scan of `sets` for every token.
        std::map<std::string, int> column_of;
        int next_col = 0;
        for (std::set<std::string>::const_iterator it = sets.begin(); it != sets.end(); ++it) {
            column_of[*it] = next_col++;
        }

        // Declared capacity of vector_old. Rows and columns beyond it are
        // dropped instead of being written out of bounds.
        const int max_rows = static_cast<int>(sizeof(vector_old) / sizeof(vector_old[0]));
        const int max_cols = static_cast<int>(sizeof(vector_old[0]) / sizeof(vector_old[0][0]));

        std::string line;
        bool is_header = true;
        int row_num = 0;
        while (std::getline(in, line)) {
            if (is_header) {
                is_header = false;
                continue;
            }
            if (row_num >= max_rows) {
                std::cerr << "vector_out: too many documents, extra rows ignored" << std::endl;
                break;
            }
            std::istringstream tokens(line);
            std::string word;
            tokens >> word;  // the first token of a line is not a vocabulary word
            while (tokens >> word) {
                std::map<std::string, int>::const_iterator hit = column_of.find(word);
                if (hit != column_of.end() && hit->second < max_cols) {
                    vector_old[row_num][hit->second] = true;
                }
            }
            ++row_num;
        }
    }

    const std::string wenben = "文本编号 ";
    out << wenben;
    for (std::set<std::string>::const_iterator it = sets.begin(); it != sets.end(); ++it) {
        out << *it << " ";
    }
}


void compute_dis(double K)
{
    for (int i = 0; i < 1246; i++){
    double sum = 0;
    for (int j = 0; j < sets.size(); j++){
        if (vector_old[i][j]) 
        {
            sum++;
        }
    }       
    for (int j = 0; j < sets.size(); j++){
        vector2[i][j] = vector_old[i][j]*1.0/sum;
        //out << vector2[i][j] << " ";
    }
        //out <
    }

    for(int mood_n = 0 ; mood_n < 6 ; mood_n++)
    {
        for(int i = 0 ; i < 1000 ; i++)
        {

            int dis_num=0;
            double pro_sum = 0;
            double dis;
            int pos;
            double max_dis = 0;
            double min_dis = 10000;
            map<double,int>map1;
            map<double,int, greater<double> >map2;

            for(int j = 0 ; j < 246 ; j++)
            {
                dis = 0;

                double angle = 0;
                double xy_sum=0;
                double xx=0;
                double yy=0;
                double save_angle[2000]={0};

                for(int k = 0 ; k < sets.size() ; k++)
                {
                    xy_sum+=vector_old[i+246][k]*vector_old[j][k];
                    xx+=vector_old[i+246][k]*vector_old[i+246][k];
                    yy+=vector_old[j][k]*vector_old[j][k];

                    //dis += (vector2[i+246][k]-vector2[j][k])*(vector2[i+246][k]-vector2[j][k]);
                }



                dis_save[j] = xx + yy - 2*xy_sum; 
                angle = xy_sum/(sqrt(xx)*sqrt(yy));
                //angle = angle*(1/sqrt(dis_save[j]));
                angle = 0.8*angle + 0.2*dis_save[j];
                map2.insert(make_pair(angle,j));

                /*
                for(int k = 0 ; k < sets.size() ; k++)
                {
                    dis += (vector2[i+246][k]-vector2[j][k])*(vector2[i+246][k]-vector2[j][k]);
                }
                dis = sqrt(dis);
                dis_sum+=dis;
                map1.insert(make_pair(dis,j));
                */
            }
            cout << "i:" << i </*
            for(map::iterator it1 = map1.begin();it1!=map1.end();it1++)
            {
                double temp = it1->first;
                temp = temp/dis_sum;
                map1.insert(make_pair(temp,it1->second));
            }
            */


            int K_i = 1;
            double dis_sum = 0;
            for(map<double,int>::iterator it = map2.begin(); it != map2.end(); it++)
            {
                if(K_i>K)
                {
                    break;
                }
                else
                {
                    K_i++;
                    pro_sum += proba[mood_n][it->second];
                    //dis_sum+=(1/(dis_save[it->second]*dis_save[it->second]));
                    /*
                    for(map::iterator it1 = map1.begin();it1!=map1.end();it1++)
                    {
                        if(it->second == it1->second)
                        {
                            pro_sum = pro_sum + 0.6*
                            break;
                        }
                    }
                    */
                }
            }
            newproba[mood_n][i] = pro_sum*1.0/K;
        }
    }
    cout << "happy" <void print()
{
    for(int i = 0 ; i < 6 ; i++)
    {

        ofstream f;
        switch(i)
        {
            case ANGER:    f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/anger_predict.txt"); break;
            case DISGUST:  f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/disgust_predict.txt"); break;
            case FEAR:     f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/fear_predict.txt"); break;
            case JOY:      f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/joy_predict.txt"); break;
            case SAD:      f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/sad_predict.txt"); break;
            case SURPRISE: f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/surprise_predict.txt"); break;
        }
        for(int j = 0 ; j < 1000 ; j++)
        {
            f << newproba[i][j] <//cout << newproba[i][j] <
        }
        f.close();
    }
}


int main()
{
    cout << "请输入k:" <cin >> K;
    get_word();
    cout << 0 <cout << 1 <cout << 2 <cout << 3 <cout << 4 <cout << 5 <cout << sets.size() <return 0;
}

你可能感兴趣的:(AI,人工智能,KNN,预测,AI,优化)