Tensorflow的负采样函数Sampled softmax loss学习笔记

最近阅读了YouTube的推荐系统论文,在代码实现中用到的负采样方法我比较疑惑,于是查了大量资料,总算能够读懂关于负采样的一些皮毛。
本文主要针对tf.nn.sampled_softmax_loss这个函数进行讲解,并引申一些数学意义,偏实用性。

类别和标签

在阅读代码的过程中发现一个问题,就是代码作者先是对样本打y,打y的依据是即user最后时刻观看的视频是否正确,那么这样的话,很明显没有负样本,而且最终的损失函数用的是tf.nn.sampled_softmax_loss,很明显是多分类,那么打二分类标签且没有负样本的依据是什么?
这里就涉及到类别和标签的区别。
以MNIST图像分类为例,标签是“该图像中的数字是否是5”,同理,就会有标签“该图像中的数字是否是6”,那么就有10个标签,每个标签都会有两个类别,即是或者不是,有读者可能就看出来,这不就是one-hot编码的逻辑嘛,事实就是这样,MNIST数据集训练时用到的损失函数是交叉熵函数,这时的y非0即1,那么one-hot编码可以理解为将多分类转换成了二分类。那如果我们不用交叉熵损失函数呢,用tf.nn.sampled_softmax_loss函数,看以下代码,

tf.nn.sampled_softmax_loss(weights, # Shape (num_classes, dim)     - floatXX
                     biases,        # Shape (num_classes)          - floatXX 
                     labels,        # Shape (batch_size, num_true) - int64
                     inputs,        # Shape (batch_size, dim)      - floatXX  
                     num_sampled,   # 负采样个数- int
                     num_classes,   # 类别数量- int
                     num_true=1,    
                     sampled_values=None,
                     remove_accidental_hits=True,
                     partition_strategy="mod",
                     name="sampled_softmax_loss")

我说明一下其中最重要的几个参数:
label——标签,如果我们用多分类的思路来做MNIST分类,那么标签就是“图像中的数字是5”,此时,label=[[5]],同理有10个标签,代码中label的shape的第二个维度是num_true,也就是说这个图像中的数字只可能是5,如果还可能是3的话,那么num_true就是2了,label也就是[[5, 3]],我们称第一种情况为多标签单分类第二种情况为多标签多分类

num_sampled——这个参数的意思是选取样本个数组成一个样本子集,这里后面会用公式说明。

num_classes——标签数量,这里是10。
多分类的本质:即遍历众多样本,每次遍历从中选取一个正确的类别(目标类别)。(自己的理解,有错误欢迎指出)
经过以上分析,我开头抛出的问题应该已经有了答案,因为都是正样本所有y是1,但是他们的标签不同,所以这个问题是一个多标签单分类的问题。

数学分析

首先贴一张图,是TensorFlow官网给出的tf.nn.sampled_softmax_loss的解释中的一张图
Tensorflow的负采样函数Sampled softmax loss学习笔记_第1张图片很明显,tf.nn.sampled_softmax_loss属于最后一行,
下面我再贴一张图,是我认为解读的还算通俗的一种解释,这是原博地址,
Tensorflow的负采样函数Sampled softmax loss学习笔记_第2张图片候选集C就是上面代码中num_sampled参数选择出来的,目标分类就是我上面总结的多分类的本质中提到的,举个例子,一个候选集中有5,3,4,2这4个数字的图像,我想挑出5,那么5这个类别就是目标类别。
由于本人统计学不是很好,公式只能勉强理解,如果大家对这个函数的公式推导感兴趣,我列出几个参考文献,大家可以去学习一下。

参考文献

[1]sampled softmax与其在框架中的使用
[2]nce loss 与 sampled softmax loss 到底有什么区别?怎么选择?
[3]Tensorflow的采样方法:candidate sampling
[4]tf.nn.sampled_softmax_loss用法详解

你可能感兴趣的:(机器学习,深度学习,人工智能,python,算法)