https://www.tensorflow.org/api_docs/python/tf/random/fixed_unigram_candidate_sampler
上面链接是官网解释,看了一会儿感觉没看懂 跑了几个列子有点懂了。
本文结合https://www.w3cschool.cn/tensorflow_python/tf_nn_fixed_unigram_candidate_sampler.html
给出更详细的解释如下:
tf.random.fixed_unigram_candidate_sampler(
true_classes,
num_true,
num_sampled,
unique,
range_max,
vocab_file='',
distortion=1.0,
num_reserved_ids=0,
num_shards=1,
shard=0,
unigrams=(),
seed=None,
name=None
)
使用提供的(固定)基本分布对一组类进行采样.
该操作从整数范围[0,range_max]中随机采样num_sampled个类,所有的类的类别是[0, range_max), 每个类被采样的概率大小由参数unigrams指定,这个参数的值可以是概率的array,也可以是int的vector(表示出现次数,次数大表示被采样的概率大)
sampling_candidates的元素是在没有替换 (如果unique = True) 或替换 (如果unique = False) 的基础分布中绘制的.
基本分布从文件中读取或作为内存中数组传入.还可以通过对权重应用distortion power(失真功率)来扭曲分布.
此外,此操作返回张量true_expected_count和sampled_expected_count,表示每个目标类(true_classes)和采样类(sampled_candidates)预期在平均张量的采样类中出现的次数.如果unique=True,则这些是拒绝后的概率,我们大致计算它们.
参数:
返回:
测试例子
import tensorflow as tf
def test1():
vec = tf.constant([[1, 2, 3, 4, 6]], dtype=tf.int64)
# vec = tf.reshape(vec, [-1, 1])
ids, _, _ = tf.nn.fixed_unigram_candidate_sampler(
true_classes=vec,
num_true=5,
num_sampled=2,
unique=False,
range_max=5,
vocab_file='',
distortion=1.0,
num_reserved_ids=0,
num_shards=1,
shard=0,
unigrams=(0.1, 0.2, 0.3, 0.1, 0.3),
# unigrams=(1, 2, 3, 1, 3),
)
# vs = ids(vec)
with tf.Session() as sess:
print sess.run(ids)
if __name__ == '__main__':
test1()
输出
[4 0]