谷歌采样修正的双塔模型

贡献

本文提出了一种从流式数据中估计item频率的新算法,通过理论推导,证明了该算法可以在无需固定item词表的情况下生效,并且能够产生无偏估计,同时能够适应item分布的变化。以解决热门商品在负样本采样时,采样次数过多而被过度惩罚。

业内的主流方法和问题

推荐领域中emb学习的挑战通常有两个:1)对于许多工业级别的应用来说item语料规模会相当大。2)采集自用户反馈的训练数据对许多item来说非常稀疏,从而导致模型预测的长尾内容有很大的方差。面对商品冷启动问题,现实世界的系统需要适应数据分布的变化,以更好地获取新鲜item。

双塔网络算法原理

双塔网络与NCF(神经协同过滤)不同,双塔网络降低耗时且修正了模型损失函数。NCF简介
利用双塔模型构架推荐系统,首先建立两个参数embedding函数,把query和候选item映射到k维向量空间,模型的输出为二者的embedding内积。
模型结构如图所示:
谷歌采样修正的双塔模型_第1张图片

In-batch loss function

推荐问题可以看作是,给定query X,从M个item中得到y的概率可以利用softmax函数计算:谷歌采样修正的双塔模型_第2张图片
考虑反馈 ri, 加权对数似然损失函数为:
谷歌采样修正的双塔模型_第3张图片
当M非常大(样本总数很大)时,我们通常可以利用负采样算法进行计算。然而对于流数据,我们考虑在同一个batch中采样负样本。此处可看成是现实场景下,训练数据是分批到达,模型训练也是分batch进行。
batch-softmax函数为:
谷歌采样修正的双塔模型_第4张图片
在每个batch中,由于存在幂律分布现象。如果在每个batch中随机采样负样本,会使热门商品更容易被采样到,在损失函数中就“过度”惩罚了这些热门商品,因此考虑用频率对采样进行修正,即:
在这里插入图片描述
其中 Pj 是在每个batch中随机采样到item j的概率(将在下一节中介绍),因此修正后的条件概率函数为:谷歌采样修正的双塔模型_第5张图片

Streaming Frequency Estimation

此方法用于估计在流数据中,每个batch下item出现的概率。上面提到的Pj。

对于一个流式的随机batch,问题是预估每个batch中每个item 的出现概率。一个关键的设计准则是当有多个训练jobs(workers)时,要有一个全局的预估来支持分布式训练。此处可以利用全局step,并对一个item的频率预估转化为deta预估,其表示为两次连续命中item所需的平均step。例如,如果一个item每50step采样一次,deta = 50,则得到p = 0.02。使用全局step提供了两点优势:1)通过读取和修改全局step,多个worker在频率预估中隐式的同步。2)预测通过简单的滑动平均来更新,该更新适用于分布的改变。

为了解决hash collision的问题,可以建立多个数组 Ai Bi 最终在多个数组中取最大。

定义两个大小为H的数组A,B,哈希函数h可以把每个item映射为[H]内的整数。

A[h(y)]表示item y上次被采样到的时刻
B[h(y)]表示每多少步item y可以被采样一次
先说结论,当第t步y被采样到时,利用迭代可更新A,B:
谷歌采样修正的双塔模型_第6张图片
alpha 可看作学习率。通过上式更新后,则在每个batch中item y出现的概率为 1/B[h(y)]。
直观上,上式可以看作利用SGD算法和固定的学习率 [公式] 来学习“可以多久被采样到一次”这个随机变量的均值。

下面,可以从数学理论上证明这种迭代更新的有效性:
谷歌采样修正的双塔模型_第7张图片谷歌采样修正的双塔模型_第8张图片

算法总结

涵盖了In-batch loss function 和 流数据频率估计 的训练算法
谷歌采样修正的双塔模型_第9张图片
流数据频率估计 算法
谷歌采样修正的双塔模型_第10张图片
改进的多元数组-频率估计算法
为了解决hash collision的问题,可以建立多个数组 Ai Bi 最终在多个数组中取最大。
谷歌采样修正的双塔模型_第11张图片

归一化&&微调

谷歌采样修正的双塔模型_第12张图片

参考:https://www.jianshu.com/p/177f49effd50

你可能感兴趣的:(谷歌采样修正的双塔模型)