using namespace std;

#define pi 3.1415926535897932384626433832795

bool L1_flag = 1;// nonzero selects the L1 norm, 0 selects L2
string version;
char buf[100000], buf1[100000];
int relation_num, entity_num;
// NOTE(review): the template arguments of the map declarations below are not
// visible in this copy of the file; confirm key/value types against the original.
map relation2id, entity2id;
map id2entity, id2relation;

map > left_entity, right_entity;
map left_num, right_num;// the int key is the relation id

//normal distribution
// Returns a pseudo-random double in the half-open interval [min, max),
// derived from a single call to the C library rand().
double rand(double min, double max)
{
    const double span = max - min;
    const double scaled = span * rand();
    return min + scaled / (RAND_MAX + 1.0);
}
double normal(double x, double miu, double sigma)
    return 1.0 / sqrt(2 * pi) / sigma*exp(-1 * (x - miu)*(x - miu) / (2 * sigma*sigma));
// Generates a normally distributed random number restricted to [min, max).
// A candidate x is drawn uniformly from [min, max); it is accepted once a second
// uniform draw from [0, normal(miu, miu, sigma)) does not exceed normal(x, miu, sigma).
double randn(double miu, double sigma, double min, double max)
{
    for (;;)
    {
        const double candidate = rand(min, max);
        const double density = normal(candidate, miu, sigma);
        const double threshold = rand(0.0, normal(miu, miu, sigma));
        if (!(threshold > density))
            return candidate;
    }
}
// Returns x squared.
double sqr(double x)
{
    const double squared = x * x;
    return squared;
}
// Returns the Euclidean (L2) length of a; 0 for an empty vector.
// NOTE(review): the element type is taken to be double; the template argument
// is not visible in this copy of the file.
double vec_len(std::vector<double>& a)
{
    double sum_of_squares = 0;
    for (std::vector<double>::size_type k = 0; k != a.size(); ++k)
    {
        const double component = a[k];
        sum_of_squares = sum_of_squares + component * component;
    }
    return sqrt(sum_of_squares);
}

class Train{
    // Known-triple lookup, indexed by add() as ok[(head id, relation id)][tail id].
    // NOTE(review): the template arguments are not visible in this copy of the file.
    map, map> ok;// the four ints are head id, relation id, tail id and a status flag
    void add(int headid, int tailid, int relationid)
        ok[make_pair(headid, relationid)][tailid] = 1;
    // Stores the hyper-parameters (n, rate, margin, method), then assigns every
    // component of relation_vec and entity_vec a value drawn from
    // randn(0, 1.0 / n, -6 / sqrt(n), 6 / sqrt(n)).
    // NOTE(review): the first four loops have no visible body in this copy of the
    // file; presumably they sized the embedding tables -- confirm against the original.
    void run(int n_in, double rate_in, double margin_in, int method_in)
        n = n_in; rate = rate_in; margin = margin_in;   method = method_in;
        for (int i = 0; i < relation_num; i++)

        for (int i = 0; i < entity_num; i++)

        for (int i = 0; i < relation_tmp.size(); i++)
        for (int i = 0; i < entity_tmp.size(); i++)

        // Random initialisation of both embedding tables.
        for (int i = 0; i < relation_num; i++)
            for (int j = 0; j < n; j++)
                relation_vec[i][j] = randn(0, 1.0 / n, -6 / sqrt(n), 6 / sqrt(n));
        for (int i = 0; i < entity_num; i++)
            for (int j = 0; j < n; j++)
                entity_vec[i][j] = randn(0, 1.0 / n, -6 / sqrt(n), 6 / sqrt(n));


    int n, method;// n: embedding dimension; method: negative-sampling mode (0 = uniform)
    double rate, margin;// rate: step size of the gradient updates; margin: ranking margin
    double res;// loss accumulated over the current epoch
    //double count, count1;//?
    //double belta;//?
    vector fb_h, fb_t, fb_r;// fb_h holds the head-entity id of each line of train.txt
    vector > feature;// not referenced in the visible code
    vector > relation_vec, entity_vec;// embedding tables, indexed [id][component]
    vector > relation_tmp, entity_tmp;// working copies updated within a batch, then copied back

    double norm(vector&a)
    {//L2-norm of the embeddings of the entities is 1
        double mo = vec_len(a);
        if (mo > 1)
            for (int i = 0; i < a.size(); i++)
                a[i] = a[i] / mo;
        return 0;
    // Returns a pseudo-random integer in [0, x). Requires x > 0.
    // The product of the two rand() draws is formed in unsigned arithmetic:
    // when RAND_MAX is large (e.g. 2^31 - 1) the signed product overflows int,
    // which is undefined behaviour, whereas unsigned wrap-around is well defined.
    int rand_max(int x)
    {
        const unsigned int product =
            static_cast<unsigned int>(rand()) * static_cast<unsigned int>(rand());
        return static_cast<int>(product % static_cast<unsigned int>(x));
    }

    // Training loop: nepoch epochs of nbatches batches. Each batch starts from a
    // copy of the current embeddings (relation_tmp / entity_tmp), processes
    // batchsize randomly chosen training triples, each paired with a corrupted
    // triple built from a random entity j, and then copies the working tables back.
    // NOTE(review): this copy of the file has lost lines and text. An `else`
    // between the two train_kb calls appears to be missing (the comment above the
    // branch refers to an if and an else); as written both corruptions run.
    // NOTE(review): rand_max(fb_h.size()) is called with 0 if no training triples
    // were loaded -- confirm callers guarantee a non-empty training set.
    void bfgs()
        res = 0;//loss
        int nbatches = 100;
        int nepoch = 1000;
        int batchsize = fb_h.size() / nbatches; // fb_h.size() == number of samples in train.txt
        for (int epoch = 0; epoch < nepoch; epoch++)
            res = 0;
            for (int batch = 0; batch < nbatches; batch++)
                relation_tmp = relation_vec;
                entity_tmp = entity_vec;
                for (int k = 0; k < batchsize; k++)
                    int i = rand_max(fb_h.size());
                    int j = rand_max(entity_num);// pick a random entity id
                    double pr = 1000 * right_num[fb_r[i]] / (right_num[fb_r[i]] + left_num[fb_r[i]]);
                    if (method == 0)// uniform sampling: set the probability to 50%
                        pr = 500;// with uniform sampling the if/else below randomly picks whether the head or the tail is replaced
                    if (rand() % 1000 < pr)
                        while (ok[make_pair(fb_h[i], fb_r[i])].count(j)>0)// count() is 1 when the triple is known; looking for a negative tail entity
                            j = rand_max(entity_num);// the triple is in train.txt, so draw another tail entity
                        train_kb(fb_h[i], fb_t[i], fb_r[i], fb_h[i], j, fb_r[i]);
                        while (ok[make_pair(j, fb_r[i])].count(fb_t[i])>0)
                            j = rand_max(entity_num);
                        train_kb(fb_h[i], fb_t[i], fb_r[i], j, fb_t[i], fb_r[i]);
                    // entity_tmp changed, so re-impose the constraint that the L2-norm of the entity embeddings is at most 1
                    // NOTE(review): the statements this comment refers to are not visible in this copy.
                relation_vec = relation_tmp;
                entity_vec = entity_tmp;

            cout << "epoch:" << epoch << ' ' << res << endl;

            // The embedding files are (re)opened for writing inside the epoch loop.
            // NOTE(review): from the next loop onward the text runs straight into the
            // body of the gradient update (parameters h_a..r_b, local x, index i);
            // the file-writing code and the gradient function header are not visible.
            FILE* f2 = fopen(("relation2vec." + version).c_str(), "w");
            FILE* f3 = fopen(("entity2vec." + version).c_str(), "w");
            for (int i = 0; i < relation_num;i++)
                for (int j = 0; j0) x = 1;//(h+r-t>0)
                else x = -1;
            // Positive triple: step the working copies against the gradient.
            relation_tmp[r_a][i] -= rate*x;
            entity_tmp[h_a][i] -= rate*x;
            entity_tmp[t_a][i] -= rate*x*(-1);

            // NOTE(review): the relation term below is read from entity_vec[r_b];
            // it presumably should be relation_vec[r_b][i] -- confirm and fix.
            x = 2 * (entity_vec[h_b][i] + entity_vec[r_b][i] - entity_vec[t_b][i]);
            if (L1_flag)// absolute value as the loss: only the sign of x is used
                if (x>0) x = 1;//(h+r-t>0)
                else x = -1;
            relation_tmp[r_b][i] -=-1*rate*x;// note the multiplication by -1 for the corrupted triple
            entity_tmp[h_b][i] -= -1*rate*x;
            entity_tmp[t_b][i] -= -1*rate*x*(-1);

    void train_kb(int h_a,int t_a,int r_a,int h_b,int t_b,int r_b)
        double posLoss = calc_sum(h_a, t_a, r_a);
        double negLoss = calc_sum(h_b, t_b, r_b);
        if (posLoss + margin - negLoss > 0)
            res += margin + posLoss - negLoss;
            gardient( h_a,  t_a,  r_a,  h_b,  t_b,  r_b);//更新梯度

Train train;
// Loads the id tables from ../data/FB15k and derives per-relation averages:
// left_num[i] / right_num[i] are the mean of the counts stored in
// left_entity[i] / right_entity[i] (sum of values divided by number of entries).
// NOTE(review): part of this function is missing in this copy of the file -- the
// progress message runs straight into the left_entity loop, so the code that
// reads relation2id.txt and the training triples is not visible.
// NOTE(review): the fopen results are used without a null check in the visible code.
void prepare()
    int mycount=0;// tracks reading progress
    FILE* f1 = fopen("../data/FB15k/entity2id.txt", "r");
    FILE* f2 = fopen("../data/FB15k/relation2id.txt", "r");
    int x;
    while (fscanf(f1, "%s%d", buf, &x) == 2)// == 2 is the number of fields successfully read
        if (mycount % 200 == 0)
            cout << "读取第"<::iterator it = left_entity[i].begin(); it != left_entity[i].end(); it++)

            sum2 = sum2 + it->second;
        left_num[i] = sum2 / sum1;
    for (int i = 0; i < relation_num; i++)
        double sum1 = 0, sum2 = 0;
        for (map::iterator it = right_entity[i].begin(); it != right_entity[i].end(); it++)
            sum1++; sum2 = sum2 + it->second;
        // NOTE(review): sum1 is 0 when right_entity[i] is empty, making this 0.0 / 0.0.
        right_num[i] = sum2 / sum1;

    cout << "relation_num=" << relation_num << endl;
    cout << "entity_num=" << entity_num << endl;

// Returns the index in argv of the option named str, or -1 when the option is
// absent. Callers read the option's value from argv[index + 1], so an option
// that appears as the very last argument has no value: that case is reported
// and treated as absent (-1) instead of returning an index whose successor is
// the terminating null entry of argv.
int ArgPos(char *str, int argc, char **argv)
{
    for (int i = 1; i < argc; i++)
    {
        if (strcmp(str, argv[i]) != 0)
            continue;  // not the option we are looking for
        if (i == argc - 1)
        {
            std::cout << "Argument missing for " << str << std::endl;
            return -1;
        }
        return i;
    }
    return -1;
}
// Entry point of the trainer: sets defaults, then overrides them from the
// -size, -margin and -method command-line options located with ArgPos.
// NOTE(review): the function is cut off after the first output statement in this
// copy of the file; the remainder (and the start of the next program's includes)
// is not visible.
int main(int argc, char **argv)
    // Example invocation:
    //D:\codes\vs\kb2e\Debug>kb2e.exe -size 111 -margin 22 -method 3

    int method = 1;// 1 selects Bernoulli sampling
    int n = 100;//dim
    double rate = 0.001;//lr
    double margin = 1;
    int i;

    if ((i = ArgPos((char *)"-size", argc, argv)) > 0) n = atoi(argv[i + 1]);// atoi converts a string to int
    // NOTE(review): margin is a double but is parsed with atoi, so a fractional
    // value such as 0.5 is truncated -- confirm whether atof was intended.
    if ((i = ArgPos((char *)"-margin", argc, argv)) > 0) margin = atoi(argv[i + 1]);
    if ((i = ArgPos((char *)"-method", argc, argv)) > 0) method = atoi(argv[i + 1]);
    cout << "dim="<#include
using namespace std;

//bool debug = false;
bool L1_flag = 1;// nonzero selects the L1 norm, 0 selects L2
string version;// suffix of the relation2vec. / entity2vec. file names
char buf[100000], buf1[100000];// buf is the scratch buffer for fscanf
int entity_num, relation_num;
int n = 100;// number of embedding components summed by cal_sum
// NOTE(review): the template arguments of the map declarations below are not
// visible in this copy of the file; confirm key/value types against the original.
map relation2id, entity2id;
mapid2relation, id2entity;

// Returns the Euclidean (L2) length of a; 0 for an empty vector.
// The vector is taken by const reference so that a call does not copy it
// (callers passing an lvalue or a temporary are unaffected).
// NOTE(review): the element type is taken to be double; the template argument
// is not visible in this copy of the file.
double vec_len(const std::vector<double>& a)
{
    double res = 0;
    for (std::vector<double>::size_type i = 0; i < a.size(); i++)
        res += a[i] * a[i];
    return sqrt(res);
}
// Squares its argument.
double sqr(double x)
{
    double result = x;
    result *= x;
    return result;
}
// Ordering predicate for std::sort over (entity id, score) pairs: true when
// a's score is strictly smaller than b's, giving an ascending sort by score.
// Declared to return bool, as a comparison predicate should, rather than double.
// NOTE(review): the pair type is taken to be (int, double), matching the
// make_pair(i, score) values pushed by Test::run; the template arguments are
// not visible in this copy of the file.
bool cmp(std::pair<int, double> a, std::pair<int, double> b)
{
    return a.second < b.second;
}
// Evaluates trained embeddings by ranking candidate head entities for each test triple.
// NOTE(review): template arguments are not visible in this copy of the file.
class Test{

    // Embedding tables, read by cal_sum as relation_vec[r][i] / entity_vec[e][i].
    vector> relation_vec, entity_vec;
    //vectorh, r, t;
    // Head, relation and tail ids of the triples iterated by run().
    vectorfb_h, fb_r, fb_t;
    // Known triples, indexed as ok[(head, relation)][tail].
    map, map >ok;
    double res;// not referenced in the visible code

    void add(int h, int t, int r, bool flag)
        if (flag)
            ok[make_pair(h, r)][t] = 1;
    double cal_sum(int h, int t, int r)
        double sum = 0;
        if (L1_flag)//L1
        for (int i = 0; i < n; i++)
            sum += fabs(entity_vec[h][i] + relation_vec[r][i] - entity_vec[t][i]);
        for (int i = 0; i < n; i++)
            sum += sqr(entity_vec[h][i] + relation_vec[r][i] - entity_vec[t][i]);
        return sum;
    // Returns a pseudo-random integer in [0, x). Requires x > 0.
    // The two rand() draws are multiplied as unsigned values: the signed product
    // can overflow int when RAND_MAX is large, which is undefined behaviour,
    // while unsigned wrap-around is well defined and needs no sign correction.
    int rand_max(int x)
    {
        const unsigned int first = static_cast<unsigned int>(rand());
        const unsigned int second = static_cast<unsigned int>(rand());
        return static_cast<int>((first * second) % static_cast<unsigned int>(x));
    }
    // Head-replacement evaluation: for every test triple, scores each entity as a
    // candidate head with cal_sum, sorts the candidates by score, and accumulates
    // raw and filtered mean rank and rank-10 counts, which are printed at the end.
    // NOTE(review): this copy of the file has lost lines and text. The code that
    // reads the embedding files, the declaration of `a`, and the body of the first
    // `if` in the ranking loop are not visible.
    void run()
        FILE* f1 = fopen(("relation2vec." + version).c_str(), "r");
        FILE* f3 = fopen(("entity2vec." + version).c_str(), "r");
        // NOTE(review): ', ' is a multi-character character literal, which is streamed
        // as an int rather than as text; the rest of this line is cut off.
        cout <<"relation_num="<< relation_num << ', ' << "entity_num="<1e-3)
                cout << "wrong_entity" << i << ' ' << vec_len(entity_vec[i]) << endl;
        fclose(f1); fclose(f3);
        //map rel_num;//relationid,number

        double hrank = 0,hrank_filter=0;// summed rank of the true head (raw / filtered)
        double hrank10num = 0,hrank10numfilter=0;// head-replacement rank-10 counts (raw / filtered)
        double m = 0;// number of known-correct candidates seen so far, used for the filtered metrics

        for (int testid = 0; testid < fb_h.size(); testid++)
            int h = fb_h[testid];//head_entity id
            int t = fb_t[testid];
            int rel = fb_r[testid];
            //rel_num[rel] += 1;
            for (int i = 0; i < entity_num; i++)
                double score=cal_sum(i, t, rel);// the head is replaced by every entity in turn
                a.push_back(make_pair(i, score));
            sort(a.begin(), a.end(), cmp);// ascending by score

            // NOTE(review): cal_sum as visible returns a non-negative distance, so after
            // an ascending sort the closest candidate sits at index 0, yet the rank is
            // counted from the back of the list -- confirm the intended direction.
            m = 0;
            for (int i = a.size() - 1; i >= 0; i--)
                // NOTE(review): this `if` has no visible body (presumably it counted m);
                // as written it governs the next `if`.
                if (ok[make_pair(a[i].first, rel)].count(t) > 0)// the candidate forms a known-correct triple

                if (a[i].first == h)// the true head
                    hrank += a.size() - i;// raw rank
                    hrank_filter += a.size() - i - m;
                    // NOTE(review): `< 10` excludes rank 10 itself -- confirm whether `<= 10` was intended.
                    if (a.size() - i < 10)
                        hrank10num += 1;
                    if (a.size() - i - m < 10)
                        hrank10numfilter += 1;

        cout << "替换头实体raw mean rank=" << hrank / fb_h.size();
        cout << "替换头实体raw rank10=" << hrank10num / fb_h.size();
        cout << "替换头实体filter mean rank=" << hrank_filter / fb_h.size();
        cout << "替换头实体filter rank10=" << hrank10numfilter / fb_h.size();



Test test;
void prepare()
    FILE* f1 = fopen("../data/FB15k/entity2id.txt","r");
    FILE* f2 = fopen("../data/FB15k/relation2id.txt", "r");
    int x;
    while (fscanf(f1, "%s%d", buf, &x) == 2)
        string s = buf;
        entity2id[s] = x;
        id2entity[x] = s;
    while (fscanf(f2, "%s%d", buf, &x) == 2)
        string s = buf;
        relation2id[s] = x;
        id2entity[x] = s;
    FILE* f_kb = fopen("../data/FB15k/test.txt", "r");
    while (fscanf(f_kb, "%s", buf) == 1)
        string s1 = buf;//h
        fscanf(f_kb, "%s", buf);
        string s2 = buf;//t
        fscanf(f_kb, "%s", buf);
        string s3 = buf;//r
        if (entity2id.count(s1)==0)
            cout << "miss entity:" << s1 << endl;
        if (entity2id.count(s2) == 0)
            cout << "miss entity:" << s2 << endl;
        if (relation2id.count(s3) == 0)
            cout << "miss relation:" << s3 << endl;
            relation2id[s3] = relation_num;

    FILE* f_kb1 = fopen("../data/FB15k/train.txt", "r");
    while (fscanf(f_kb1, "%s", buf) == 1)
        string s1 = buf;
        fscanf(f_kb1, "%s", buf);
        string s2 = buf;
        fscanf(f_kb1, "%s", buf);
        string s3 = buf;
        if (entity2id.count(s1) == 0)
            cout << "miss entity:" << s1 << endl;
        if (entity2id.count(s2) == 0)
            cout << "miss entity:" << s2 << endl;
        if (relation2id.count(s3) == 0)
            relation2id[s3] = relation_num;
        test.add(entity2id[s1], entity2id[s2], entity2id[s3], true);

    FILE* f_kb2 = fopen("../data/FB15k/valid.txt", "r");
    while (fscanf(f_kb2, "%s", buf) == 1)
        string s1 = buf;
        fscanf(f_kb2, "%s", buf);
        string s2 = buf;
        fscanf(f_kb2, "%s", buf);
        string s3 = buf;
        if (entity2id.count(s1) == 0)
            cout << "miss entity:" << s1 << endl;
        if (entity2id.count(s2) == 0)
            cout << "miss entity:" << s2 << endl;
        if (relation2id.count(s3) == 0)
            relation2id[s3] = relation_num;
        test.add(entity2id[s1], entity2id[s2], relation2id[s3], true);

// Entry point of the tester: takes the embedding-file version suffix from the
// first command-line argument and returns immediately when none is given.
// NOTE(review): the rest of main is not visible in this copy of the file; the
// indentation of the assignment suggests it sat in a branch whose keyword is missing.
int main(int argc,  char** argv)
    if (argc < 2)
        return 0;
        version = argv[1];

