Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析

文章目录

  • 官方代码
  • 数据集下载
  • preprocess
  • utils
    • init
    • generate_batch
    • get_slice
  • Model
    • Get item embedding
      • GNN forward
      • SessionGraph forward
    • Get Session embedding
  • 模型评估

Session-based Recommendation with Graph Neural Networks
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第1张图片

官方代码

代码仓库:https://github.com/CRIPAC-DIG/SR-GNN#paper-data-and-code
有torch和tf两个版本

数据集下载

一个是YOOCHOOSE,仓库的链接失效了,去kaggle下就好,https://www.kaggle.com/chadgostopp/recsys-challenge-2015
另一个是DIGINETICA,https://competitions.codalab.org/competitions/11161
下载完后解压到datasets里面
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第2张图片
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第3张图片
最终就用到这两个文件,把其它无关的删掉就好
代码自带了一个 sample_train-item-views.csv 小数据集,用来验证代码正确性的,这个跑过了没问题

改一下这个文件头,让代码可以读取
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第4张图片
这个train-item-views就是session、user、item、timframe、eventdate五个部分,userId里面很多项都是NA

然后接下来就可以训练了

preprocess

预处理数据集,这个文件从原始的数据集中划分出训练集、测试集
这个文件主要做了几个处理:

  1. 对物品id进行了重构,使物品的id从1开始

  2. 为了消除bias,两个过滤:
    过滤掉所有session累计点击数小于5的物品,过滤掉点击物品数小于2的序列

  3. 为了扩展数据集的数量,这里做了一下session拆分,处理方法如下:
    对于[1,3,66,5]的session,我们可以逐步划分为:

    [1] target = 3
    [1,3] target = 66
    [1,3,66] target = 5

    这三个session

做了三个处理后,用pickle写入文件中

Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第5张图片

utils

这个文件提供了一些辅助功能用以处理数据
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第6张图片

init

首先是初始化调用 data_masks,它把slice中每个session用0填充到固定长度,由于我们给物品重新编号过,第一个物品的id从1开始,所以下标为0表示没有物品。填充到固定长度是为了方便计算,同时返回对应的mask session,它也是一个序列,里面真实的物品值为1,填充物品为0

generate_batch

generate_batch划分数据集并产生对应的slice

get_slice

get_slice接受一个slice,然后对于这个slice对应的所有session,它返回以下信息

  1. alias_input
    session中物品对应于去重排序后的 session 的 id
  2. items
    存储去重排序并补全后的session,和alias_input搭配使用查找A中的内容。
    items的语义是每个session的物品空间,并不考虑时效性,所以是排序去重补全后的序列,我们用它来生成物品的embedding。
  3. A
    session对应的出入度矩阵,经过归一化处理
  4. mask
    填充后的session,item id 换为 1,填充位为0
  5. target
    next item 的 id

这里解释一下 alias_input、A和session的具体含义
假设一个session是[5,11,7,0,0,0],那么它去重排序后对应的序列是[0,5,7,11]
alias_input的值是[1,3,2,0,0,0]
以5为例,alias_input中5的值是1,对应去重排序后的序列中第1个物品

为什么要取得alias_input呢,因为我们对每个session都生成一个出入度矩阵A,以出度矩阵为例:

[[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

这里A[1,3] = 1,A[3,2] = 1,恰好是alias_input里面的值生成的出度矩阵A
为什么要这么绕,因为A的下标并不是以序列第一个物品的id 5开始的,而是从0开始,所以要用一个额外的序列来帮助标记


Model

Get item embedding

搞清楚上面的逻辑之后,我们就可以开始训练了

首先我们不看model的定义,先看训练过程
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第7张图片
这里的步骤很清晰
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第8张图片
重点是里面的forward
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第9张图片
第一行就是熟悉的get_slice了,得到数据之后,我们使用 hidden = model(items, A) 获得隐藏信息,这一个语句就是在调用model的forward
model一共有两个,第二个SessionGraph是集成了GNN的model
在这里插入图片描述

GNN forward

这里是调用SessionGraph的forward,里面有一个self.gnn,就是GNN的forward
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第10张图片
GNN的forward如下
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第11张图片
传入的A和hidden分别是这个slices对应的矩阵A和item的embedding,然后我们看计算

在这里插入图片描述
在这里插入图片描述
为什么是input_in、input_out两个结果呢,这是因为原始的A是 n ∗ 2 n n*2n n2n的矩阵,第一个 n ∗ n n*n nn的矩阵是入度矩阵,第二个是出度矩阵
这里A的shape是100x6x12100是这个slice涉及100条session,后面的6x12是每个session生成的矩阵A
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第12张图片
对于A[:, :, :A.shape[1]]我们取得出度矩阵,A[:, :, A.shape[1]: 2 * A.shape[1]]取得入度矩阵(至于代码为什么是反过来的input_in、input_out没看懂。。)
在这里插入图片描述
这一步拼接两个计算结果变成inputs

Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第13张图片
这一块对应论文的2、3、4
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第14张图片
上面代码里面gi是什么不太清楚
论文里提到z是reset gate,r是update gate,但是这里两个却是resetgateinputgate。。

在这里插入图片描述
最后返回结果对应公式5

不过为什么是用inputgate来乘呢。。reset gate z s , i t z^t_{s,i} zs,it去哪了。。

这里有几个不清楚的地方

  1. gi、gh分别是什么
  2. inputgate、resetgate分别是什么,和论文的reset gate、update gate似乎对应不上
  3. hy对应的是公式5吗?但是为什么计算方法不一样?

翻了一下issue,发现不止我一个人有这样的问题:
https://github.com/CRIPAC-DIG/SR-GNN/issues/35
https://github.com/CRIPAC-DIG/SR-GNN/issues/49
https://github.com/CRIPAC-DIG/SR-GNN/issues/15
emmmm… 看了下两位作者也意识到了代码的错误,但是后面并没有改正过来

总之这里GNN forward的返回结果得到了
对应论文下面画框的模块
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第15张图片
总结一下这部分做的工作

  1. 首先传递items,A两个变量给SessionGraph,然后调用forward
  2. items使用embedding得到hidden变量,然后传递给GNN模型,调用forward
  3. GNN里面做的工作是学习item embedding,对于矩阵A和items的embedding结果,它返回一个hidden结果

至此我们通过GNN得到item的 embedding

SessionGraph forward

Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第16张图片
就是返回刚刚的结果

Get Session embedding

对应论文里面的公式6和公式7
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第17张图片
意思是用 item 的 embedding 用来表示 session

get的含义是获取 item 的embedding,下面一个seq_hidden是生成session的表示
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第18张图片

注意力机制在哪? 没看到

然后我们这里返回的是targetsmodel.compute_score(...)

下面看一下compute_scoreSession-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第19张图片
对应的是公式8
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第20张图片
和模型图中
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第21张图片

计算后返回到train_test
Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第22张图片
计算都完毕了,可以计算loss了。
至此就是一个完成的训练过程

模型评估

Session-based Recommendation with Graph Neural Networks,SR-GNN代码分析_第23张图片
是train_test的后面部分


简单过了一下官方的SR-GNN代码,发现几个和论文对应不上的错误
大致学习了一下处理session的方法和整个model的训练流程
关于GNN和GGNN的文章还是要多看

你可能感兴趣的:(论文阅读)