Embedding算法Line源码简读

说明

其主页地址为:https://github.com/tangjianpku/LINE。
里面有详细的介绍说明,此文章只是中文简单记录。比如:

line.cpp, the souce code of the LINE;

reconstruct.cpp, the code used for reconstructing the sparse networks into dense ones, which is described in Section 4.3;

normalize.cpp, the code for normalizing the embeddings (l2 normalization);

concatenate.cpp, the code for concatenating the embeddings with 1st-order and 2nd-order;

先看源码说明,line.cpp是算法主要文件,reconstruct.cpp是用于将稀疏网络重建为密集网络的源文件等,这里主要写下line里面的一些函数。

main函数

首先从main函数开始往下看。ArgPos读参数,然后转入TrainModel()进行训练。还有就是,读别人的代码是非常痛苦的一件事情,不过作者代码还是比较规范,没有没有意义的函数名,很多函数一看名称就知道个大概,还有就是关键地方写了注释。

TrainLINE函数

InitHashTable();初始化哈希表

ReadData();读取文件数据,并把它作为(v,u,w)表示,u、u是vertex,w是权重

InitAliasTable();初始化偏倚表,可能翻译不太对,文章里面有说,为了提高效率及performance,解决长链传导过程中导致参数很大的越来越大,越小的会逐渐消失问题,类似深度学习的梯度消失和爆炸问题,文章将采用带偏倚的采样,权重越大概率越高。

InitVector();初始化相关向量。

InitNegTable();初始化负采样表。

InitSigmoidTable();初始化sigmoid表,后面可以以查找的形式直接取出值。

Update函数

个人觉得这个函数是最重要的

void Update(real *vec_u, real *vec_v, real *vec_error, int label)
{
    real x = 0, g;
    for (int c = 0; c != dim; c++) x += vec_u[c] * vec_v[c];#算出v、u之间的点积,可以理解为相似度
    g = (label - FastSigmoid(x)) * rho;#用sigmoid函数算出v、u之间相似的概率0-1,点积越大越接近1。与label减以后得到误差,乘以学习率rho是学习多少误差,即对误差更新多少
    for (int c = 0; c != dim; c++) vec_error[c] += g * vec_v[c];#误差累积量
    for (int c = 0; c != dim; c++) vec_v[c] += g * vec_u[c];#根据误差来更新v向量
}

其他函数

有的函数不重要,知道功能就行,但是想要深入理解,还得细读负采样及重要性边采样两个地方,其他函数简单说明如下。

unsigned int Hash(char *key)根据key得到一个hash值,个人跑了代码,得到的是ASCII码,细节不谈了

void InitHashTable(),void InsertHashTable(char *key, int value),int SearchHashTable(char *key)顾名思义,初始化哈希表,向哈希表插入,以及根据Key搜索。

int AddVertex(char *name)这个是非常重要的,向vertex set中添加一个vertex

void ReadData()根据文件生成所需格式(u,v,w)的文件,读一条边即可知道u,v,w三个参数。

void InitAliasTable()初始化偏倚表,用来采样一条边,得到一条边后便得到了相应顶点及权重。

long long SampleAnEdge(double rand_value1, double rand_value2)采样一条边的具体实现

void InitVector()初始化u、v向量。

void Output()得到的结果以文件形式保存到某个地方。

其他就不说了,有空再来细读,但是读别人代码真的慢,特别是对c++还不熟。我的方式是一个函数一个函数debug试一试看具体功能,如果后面再细跑了代码再来纠正或者另外写点解读。

你可能感兴趣的:(c++)