代码仓库:https://github.com/CRIPAC-DIG/SR-GNN#paper-data-and-code
有torch和tf两个版本
一个是YOOCHOOSE,仓库的链接失效了,去kaggle下就好,https://www.kaggle.com/chadgostopp/recsys-challenge-2015
另一个是DIGINETICA,https://competitions.codalab.org/competitions/11161
下载完后解压到datasets里面
最终就用到这两个文件,把其它无关的删掉就好
代码自带了一个 sample_train-item-views.csv 小数据集,用来验证代码正确性的,这个跑过了没问题
改一下这个文件头,让代码可以读取
这个train-item-views就是session、user、item、timframe、eventdate五个部分,userId里面很多项都是NA
然后接下来就可以训练了
预处理数据集,这个文件从原始的数据集中划分出训练集、测试集
这个文件主要做了几个处理:
对物品id进行了重构,使物品的id从1开始
为了消除bias,两个过滤:
过滤掉所有session累计点击数小于5的物品,过滤掉点击物品数小于2的序列
为了扩展数据集的数量,这里做了一下session拆分,处理方法如下:
对于[1,3,66,5]的session,我们可以逐步划分为:
[1] target = 3
[1,3] target = 66
[1,3,66] target = 5
这三个session
做了三个处理后,用pickle写入文件中
首先是初始化调用 data_masks,它把slice中每个session用0填充到固定长度,由于我们给物品重新编号过,第一个物品的id从1开始,所以下标为0表示没有物品。填充到固定长度是为了方便计算,同时返回对应的mask session,它也是一个序列,里面真实的物品值为1,填充物品为0
generate_batch划分数据集并产生对应的slice
get_slice接受一个slice,然后对于这个slice对应的所有session,它返回以下信息
这里解释一下 alias_input、A和session的具体含义
假设一个session是[5,11,7,0,0,0],那么它去重排序后对应的序列是[0,5,7,11]
alias_input的值是[1,3,2,0,0,0]
以5为例,alias_input中5的值是1,对应去重排序后的序列中第1个物品
为什么要取得alias_input呢,因为我们对每个session都生成一个出入度矩阵A,以出度矩阵为例:
[[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]]
这里A[1,3] = 1,A[3,2] = 1,恰好是alias_input里面的值生成的出度矩阵A
为什么要这么绕,因为A的下标并不是以序列第一个物品的id 5开始的,而是从0开始,所以要用一个额外的序列来帮助标记
搞清楚上面的逻辑之后,我们就可以开始训练了
首先我们不看model的定义,先看训练过程
这里的步骤很清晰
重点是里面的forward
第一行就是熟悉的get_slice了,得到数据之后,我们使用 hidden = model(items, A)
获得隐藏信息,这一个语句就是在调用model的forward
model一共有两个,第二个SessionGraph是集成了GNN的model
这里是调用SessionGraph的forward,里面有一个self.gnn,就是GNN的forward
GNN的forward如下
传入的A和hidden分别是这个slices对应的矩阵A和item的embedding,然后我们看计算
为什么是input_in、input_out两个结果呢,这是因为原始的A是 n ∗ 2 n n*2n n∗2n的矩阵,第一个 n ∗ n n*n n∗n的矩阵是入度矩阵,第二个是出度矩阵
这里A的shape是100x6x12
,100
是这个slice涉及100条session,后面的6x12
是每个session生成的矩阵A
对于A[:, :, :A.shape[1]]
我们取得出度矩阵,A[:, :, A.shape[1]: 2 * A.shape[1]]
取得入度矩阵(至于代码为什么是反过来的input_in、input_out没看懂。。)
这一步拼接两个计算结果变成inputs
这一块对应论文的2、3、4
上面代码里面gi是什么不太清楚
论文里提到z是reset gate,r是update gate,但是这里两个却是resetgate
和inputgate
。。
不过为什么是用inputgate来乘呢。。reset gate
z s , i t z^t_{s,i} zs,it去哪了。。
这里有几个不清楚的地方
翻了一下issue,发现不止我一个人有这样的问题:
https://github.com/CRIPAC-DIG/SR-GNN/issues/35
https://github.com/CRIPAC-DIG/SR-GNN/issues/49
https://github.com/CRIPAC-DIG/SR-GNN/issues/15
emmmm… 看了下两位作者也意识到了代码的错误,但是后面并没有改正过来
总之这里GNN forward的返回结果得到了
对应论文下面画框的模块
总结一下这部分做的工作
items,A
两个变量给SessionGraph,然后调用forward至此我们通过GNN得到item的 embedding
对应论文里面的公式6和公式7
意思是用 item 的 embedding 用来表示 session
get的含义是获取 item 的embedding,下面一个seq_hidden是生成session的表示
注意力机制在哪? 没看到
然后我们这里返回的是targets
和model.compute_score(...)
下面看一下compute_score
对应的是公式8
和模型图中
计算后返回到train_test
中
计算都完毕了,可以计算loss了。
至此就是一个完成的训练过程
简单过了一下官方的SR-GNN代码,发现几个和论文对应不上的错误
大致学习了一下处理session的方法和整个model的训练流程
关于GNN和GGNN的文章还是要多看