Tensorflow小技巧整理:tf.multinomial()采样

tf.multinomial()

做生成任务时,得到 decoder 最终的输出之后,就需要决策选如何利用得到的输出张量进行生成。tf.argmax()是最简单最粗暴的一种方法,直接选取概率最大的词汇作为输出。beam search 等算法的出现,使得生成的结果有了更多的可能性。最近看到一段代码,使用的是 tf.multinomial() 进行采样,也尝试用了一下。

tf.multinomial(logits, num_samples, seed=None, name=None)

logits是一个二维张量,num_samples指的是采样的个数。其实很好理解,我们生成每个时刻的 logits 时,输出维度应该是 [ batch_size, vocab_size ] 形式的,代表着该时刻,每一个batch对应的词典中各词汇生成的概率。tf.multinomial() 将按照该概率分布进行采样,返回的值是 logits 第二维上的 id,也就是我们需要的字典的 id。
举个例子:

比如每次将从5个候选词汇中采样,概率分布如图所示,采样个数为100,统计一下结果如下:
Tensorflow小技巧整理:tf.multinomial()采样_第1张图片
可以看到,第一个词和最后一个词的采样次数会高很多,而概率为 0.05 的第二个词和第三个词则很少被采样到。如果5个词概率相同:
则我们的采样结果为:
Tensorflow小技巧整理:tf.multinomial()采样_第2张图片
可以看到,每个词所被采到的次数大致是相等的。 在实际生成中,一个训练良好的模型,会大概率生成效果与 argmax() 采样结果一致,但也有一定的几率生成概率较低的词汇,从而也能够改善最终生成的效果。

你可能感兴趣的:(tensorflow使用整理)