lda c++代码

#include

using namespace std;

typedef long long ll;

const int INF=0x3f3f3f3f;
const int MAX_M=1e4+5;
const int MAX_K=30+5;
const int MAX_N=1e6+5;

struct node{
    int id,k;//词的编码、词的主题
    node(){}
    node(int id_,int k_):id(id_),k(k_){}
};

int M,K,N;//文档数、主题数、词数
string dic[MAX_N+1];//词表
map<string,int> mp;//词表map
double doc[MAX_M+1][MAX_K+1];//每篇文档的主题分布
double top[MAX_K+1][MAX_N+1];//每个主题的词分布
vector<node> doc_words[MAX_M+1];//每篇文档中的词
double alpha[MAX_K+1],beta[MAX_N+1];//超参数向量

double doci[MAX_M+1][MAX_K+1];//实时更新
double topi[MAX_K+1][MAX_N+1];//实时更新

vector<string> data[MAX_M];//测试数据

int multinomial_rand(double P[],int n){//多项分布随机数
    double h=(double)rand()/(RAND_MAX+1);
    double s=0;
    for(int i=1;i<=n;i++){
        s+=P[i];
        if(h<s)return i;
    }
    return n;
}

void init(){//初始化函数
    for(int i=1;i<=K;i++)alpha[i]=1;
    for(int i=1;i<=N;i++)beta[i]=1;
    for(int i=1;i<=K;i++)alpha[0]+=alpha[i];
    for(int i=1;i<=N;i++)beta[0]+=beta[i];

    for(int i=1;i<=M;i++){
        for(int j=0;j<doc_words[i].size();j++){
            node &w=doc_words[i][j];
            w.k=rand()%K+1;
            doci[i][w.k]++;doci[i][0]++;
            topi[w.k][w.id]++;topi[w.k][0]++;
        }
    }
}

void gibbs_sampling(int T){//GibbsSampling采样
    while(T--){
        for(int i=1;i<=M;i++){
            for(int j=0;j<doc_words[i].size();j++){
                node &w=doc_words[i][j];
                double P[MAX_K+1];
                for(int k=1;k<=K;k++){
                    double t=w.k==k?1:0;
                    double ans1=(doci[i][k]+alpha[k]-t)/(doci[i][0]+alpha[0]-1);
                    double ans2=(topi[k][w.id]+beta[w.id]-t)/(topi[k][0]+beta[0]-t);
                    P[k]=ans1*ans2;
                }
                double s=0;
                for(int k=1;k<=K;k++)s+=P[k];
                for(int k=1;k<=K;k++)P[k]/=s;

                int k0=w.k;
                int k1=multinomial_rand(P,K);
                w.k=k1;

                doci[i][k0]--;doci[i][k1]++;
                topi[k0][w.id]--;topi[k0][0]--;
                topi[k1][w.id]++;topi[k1][0]++;

                doc[i][k1]++;
                top[k1][w.id]++;
            }
        }
    }
}

void normalize(){//矩阵归一化
    for(int i=1;i<=M;i++){
        double s=0;
        for(int j=1;j<=K;j++)s+=doc[i][j];
        for(int j=1;j<=K;j++)doc[i][j]/=s;
    }
    for(int i=1;i<=K;i++){
        double s=0;
        for(int j=1;j<=N;j++)s+=top[i][j];
        for(int j=1;j<=N;j++)top[i][j]/=s;
    }
}

int kkk;
int rak[MAX_N+1];
bool cmp(int i,int j){
    return top[kkk][i]>top[kkk][j];
}

void show(){
    for(int i=1;i<=K;i++){
        cout<<i<<":";
        kkk=i;
        for(int j=1;j<=N;j++)rak[j]=j;
        sort(rak+1,rak+N+1,cmp);
        for(int j=1;j<=10;j++){
            int id=rak[j];
            cout<<dic[id]<<"*";
            printf("%.4f   ",top[i][id]);
        }
        cout<<endl;
    }
}

int main(){
    srand(time(NULL));
    freopen("C:\\Users\\28612\\Desktop\\data","r",stdin);
    M=0,K=18,N=0;
    cout<<"主题数K:"<<K<<endl;
    int t;
    while(cin>>t){
        M++;
        while(t--){
            string s;
            cin>>s;
            data[M].push_back(s);
        }
    }
    cout<<"文档数M:"<<M<<endl;

    for(int i=1;i<=M;i++){
        for(int j=0;j<data[i].size();j++){
            string s=data[i][j];
            if(mp[s]==0){
                dic[++N]=s;
                mp[s]=N;
            }
        }
    }
    cout<<"词汇数N:"<<N<<endl;

    for(int i=1;i<=M;i++){
        for(int j=0;j<data[i].size();j++){
            string s=data[i][j];
            int id=mp[s];
            doc_words[i].push_back(node(id,1));
        }
    }

    init();
    gibbs_sampling(100);
    normalize();
    show();
    return 0;
}

你可能感兴趣的:(nlp)