【Bert】(十二)简易问答系统--源码解析(bert后处理模型+损失函数)

论文:https://arxiv.org/pdf/1810.04805.pdf

官方代码:GitHub - google-research/bert: TensorFlow code and pre-trained models for BERT

bert后处理模型

在run_squad.py中的create_model函数中,“bert后处理模型”代码为:

  final_hidden = model.get_sequence_output()

  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  batch_size = final_hidden_shape[0]
  seq_length = final_hidden_shape[1]
  hidden_size = final_hidden_shape[2]

  output_weights = tf.get_variable(
      "cls/squad/output_weights", [2, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

  final_hidden_matrix = tf.reshape(final_hidden,
                                   [batch_size * seq_length, hidden_size])
  logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)

  logits = tf.reshape(logits, [batch_size, seq_length, 2])
  logits = tf.transpose(logits, [2, 0, 1])

  unstacked_logits = tf.unstack(logits, axis=0)

  (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

  return (start_logits, end_logits)

 【Bert】(十二)简易问答系统--源码解析(bert后处理模型+损失函数)_第1张图片

最终得到的start_logits, end_logits,他们的形状都为【batchsize, seq_length】。

这种处理方式只适合一问一答的情况。

损失函数

      def compute_loss(logits, positions):
        one_hot_positions = tf.one_hot(
            positions, depth=seq_length, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        loss = -tf.reduce_mean(
            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
        return loss

      start_positions = features["start_positions"]
      end_positions = features["end_positions"]

      start_loss = compute_loss(start_logits, start_positions)
      end_loss = compute_loss(end_logits, end_positions)

      total_loss = (start_loss + end_loss) / 2.0

在start_logits中概率最大的认为是起始位置,end_logits中概率最大的认为是终止位置。

根据这样的理念结合交叉熵损失,就能得到上述代码描述的情况。

你可能感兴趣的:(NLP,bert,深度学习,自然语言处理)