RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型

文章目录

  • 1.总览
  • 2.考虑到bias的softmax损失修正
  • 3.如何计算batch内item的采样概率?
  • 4.其他的一些tricks梳理
    • 4.1 Embedding标准化
    • 4.2 softmax增强
  • 5.代码

论文链接:Sampling-Bias-Corrected Neural Modeling for Large Corpus Item Recommendations

1.总览

目前,业界的推荐系统可以分成Retrieval和Ranking两个阶段,Retrieval需要从百万级以上的item库中召回到千级item作为排序模型的一路输入。而双塔模型是业界主流的召回方法,比如论文中给到的模型即为Youtube当时最新的召回模型:
RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型_第1张图片
这个模型在整体上就是最普通的双塔。左边是user塔,输入包括两部分,第一部分seed是user当前正在观看的视频,第二部分user的feature是根据user的观看历史计算的,比如说可以使用user最近观看的k条视频的id emb均值,这两部分融合起来一起输入user侧的DNN。右边是item塔,将候选视频的feature作为输入,计算item的 embedding。之后再计算相似度,做排序就可以了。
我们知道召回的过程可以视为一个多分类问题,因此,模型的输出层会选择softmax计算后再计算交叉熵损失。但当检索空间特别大时(也就是Item特别多,比较成熟的App通常都会面对100M,甚至更大),softmax函数的计算耗时是不能接受的,所以通常的想法,从全量的Items中采样出一个Batch,然后在这个Batch上计算softmax。也就是我们所说的负采样。
考虑到业界中实际的数据是基于流数据的,每一天都会产生新的训练数据。因此,负样本的选择只能在batch内进行,batch内的所有样本作为彼此的负样本去做batch softmax。这种采样的方式带来了非常大的bias。一条热门视频,它的采样概率更高,因此会更多地被当做负样本,这不符合实际。因此这篇工作的核心就是减小batch内负采样带来的bias。

2.考虑到bias的softmax损失修正

对于热门item,它在一个batch中有更大的概率被采样到,这会导致embedding的更新更偏向于热门item,加重长尾分布数据下的马太效应。所以一个直观的想法是惩罚热门item的softmax概率:
RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型_第2张图片
s(x,y)是user塔的embedding输出和item塔embedding输出的内积,用于衡量二者的相似度,对于经典softmax多分类,我们可以得到下面的softmax计算式和优化似然函数:
RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型_第3张图片
那么如何对热门item的softmax进行修正呢?很简单,在每个item和user计算得到的内积基础上,加上一个log(item被采样概率)的惩罚项,热门item的p_j更大。新的s(x,y)可以缩小计算softmax时的bias,后续可以用梯度下降or其他优化算法训练模型。
RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型_第4张图片

3.如何计算batch内item的采样概率?

这部分主要对采样概率进行估计,这里的核心思想是假设某视频连续两次被采样的平均间隔为B,那么该视频的采样概率即为1/B,如果该商品上一次被采样的时刻为A的话,那么当该商品在时刻t被采样时,则利用A辅助更新B:
RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型_第5张图片
式中的函数h()是一个hash函数,他将某个视频的id映射到具体的索引上,然后利用该索引从矩阵B和矩阵A中分别得到该商品对应的平均采样间隔和上一次该商品被采样的时刻,从而进行梯度更新。当B更新完之后,需要对A进行更新(将时刻t赋值给A)。

既然是hash过程,当H

RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型_第6张图片
论文中,作者对哈希方程的数量做了实验,以严重单hash和multi hash的优劣:
RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型_第7张图片
可以看到,使用更多的Hash方程数量,误差越小,也就是说多hash的方式在论文的实验中呈现更好的效果。

4.其他的一些tricks梳理

4.1 Embedding标准化

即对user塔和item塔的输出embedding进行L2标准化,实践证明这是个工程上的tricks:
在这里插入图片描述
在这里插入图片描述

4.2 softmax增强

就是对于内积计算的结果,除以一个固定的超参
在这里插入图片描述
除以超参的效果如下,可以看到softmax的效果更加明显(sharpen)
RecSys 2019:对in-batch负采样进行bias校正的Google双塔模型_第8张图片
超参的设定可以通过实验结果的召回率或者精确率进行微调。其实这个trick在其他方法中也很常见,比如知识蒸馏中的带温度的softmax,和Transformer中,计算Q和V内积softmax后的scale,本质上都是让softmax分类学习的更充分。

5.代码

关于这篇论文有个开源的基于tf2.0的复现,贴上链接:
代码链接

参考:

  1. Sampling-Bias-Corrected Neural Modeling for Large Corpus Item Recommendations
  2. RS Meet DL(72)-[谷歌]采样修正的双塔模型

你可能感兴趣的:(搜推广,推荐系统,算法)