bert finetune 分类模型预加载

bert 代码中使用了TensorFlow的高级API estimator,但是这样训练出来的模型是不支持预加载到内存当中的,每次进行预测都要加载一遍模型,离线处理的话还可以接受,如果要在线预测的话效率就会大打折扣,这里提出一个解决方案。
google 开源的TensorFlow serving是专门针对TensorFlow框架训练出来的模型设计的一个服务,可以非常方便的启动预测服务,并支持grpc和restful。同时有GPU的版本,可以达到毫秒级预测。

  • 针对run_classifier
    这个比较简单,首先增加一个函数
def serving_input_fn():
    label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn

其中label_ids,input_ids,input_mask,segment_ids就是模型的输入,可以到源码中找到,其变量类型也可一并找到。然后在main()函数中的

if FLAGS.do_train:
	...
	estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

后面加上

estimator._export_to_tpu = False
estimator.export_savedmodel(path_to_save_model, serving_input_fn)

就可以了。这样模型训练完后就会把pb模型保存到path_to_save_model路径下,然后用TensorFlow serving加载模型即可启动服务。
注意:这个模型的输入是经过bert 编码后的文本信息,因此调用该服务前需要首先对文本进行处理,这里源码中已经给出了,由于在线预测的服务每次输入都是单条数据,因此源码函数

def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
  """Converts a single `InputExample` into a single `InputFeatures`."""

  if isinstance(example, PaddingInputExample):
    return InputFeatures(
        input_ids=[0] * max_seq_length,
        input_mask=[0] * max_seq_length,
        segment_ids=[0] * max_seq_length,
        label_id=0,
        is_real_example=False)

  label_map = {}
  for (i, label) in enumerate(label_list):
    label_map[label] = i

  tokens_a = tokenizer.tokenize(example.text_a)
  tokens_b = None
  if example.text_b:
    tokens_b = tokenizer.tokenize(example.text_b)

  if tokens_b:
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
  else:
    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens_a) > max_seq_length - 2:
      tokens_a = tokens_a[0:(max_seq_length - 2)]
    
  tokens = []
  segment_ids = []
  tokens.append("[CLS]")
  segment_ids.append(0)
  for token in tokens_a:
    tokens.append(token)
    segment_ids.append(0)
  tokens.append("[SEP]")
  segment_ids.append(0)

  if tokens_b:
    for token in tokens_b:
      tokens.append(token)
      segment_ids.append(1)
    tokens.append("[SEP]")
    segment_ids.append(1)

  input_ids = tokenizer.convert_tokens_to_ids(tokens)

  # The mask has 1 for real tokens and 0 for padding tokens. Only real
  # tokens are attended to.
  input_mask = [1] * len(input_ids)

  # Zero-pad up to the sequence length.
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  label_id = label_map[example.label]
  if ex_index < 5:
    tf.logging.info("*** Example ***")
    tf.logging.info("guid: %s" % (example.guid))
    tf.logging.info("tokens: %s" % " ".join(
        [tokenization.printable_text(x) for x in tokens]))
    tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
    tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
    tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
    tf.logging.info("label: %s (id = %d)" % (example.label, label_id))

  feature = InputFeatures(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      label_id=label_id,
      is_real_example=True)
  return feature

就可以满足要求了,只需做少许调整即可调用TensorFlow serving起的预测服务。
推荐使用docker 部署TensorFlow serving,这样可以避免很多端口问题,然后在docker内部利用flask和requests搭建自己的在线预测服务就搞定了。
bert finetune 问答模型预加载看这里。

你可能感兴趣的:(python,bert,TensorFlow,自然语言处理,深度学习,tensorflow,机器学习,神经网络)