tf.nn.sampled_softmax_loss

def sampled_softmax_loss(weights,
                         biases,
                         labels,
                         inputs,
                         num_sampled,
                         num_classes,
                         num_true=1,
                         sampled_values=None,
                         remove_accidental_hits=True,
                         partition_strategy="mod",
                         name="sampled_softmax_loss",
                         seed=None)

计算和返回sampled softmax训练损失。

在类别很多的情况下训练softmax分类器的高效的方法。

注意:仅在训练时使用这一采样操作,测试时还是用全部的类别,比如如下使用:

if mode == "train":
  loss = tf.nn.sampled_softmax_loss(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      ...,
      partition_strategy="div")
elif mode == "eval":
  logits = tf.matmul(inputs, tf.transpose(weights))
  logits = tf.nn.bias_add(logits, biases)
  labels_one_hot = tf.one_hot(labels, n_classes)
  loss = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=labels_one_hot,
      logits=logits)

 

你可能感兴趣的:(tf.nn.sampled_softmax_loss)