注意：在 Windows 下运行 .sh 脚本需要先安装 Git 工具。进入 Git 安装目录下的 …/Git/bin 文件夹，运行 sh.exe，在弹出的窗口中 cd 到脚本所在目录，然后输入 sh train.sh 运行。
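在 bert_classifier.py 中编写一个自定义的 MyProcessor 类，它继承自 run_classifier.py 中的 DataProcessor。
编写自己的文本处理器时需要注意：数据目录下的文件需分别命名为 train.data、val.data、test.data 和 pred.data；每个文件均为 TSV（制表符分隔）格式，第一列为标签、第二列为文本，且第一行会作为表头被跳过；标签取值为 -1、0、1，test 和 pred 数据的标签统一填充为 "0"。代码如下：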
class MyProcessor(DataProcessor):
    """Data processor for a single-sentence, three-label (-1 / 0 / 1) classification task.

    Each split is a tab-separated file inside ``data_dir``: ``train.data``,
    ``val.data``, ``test.data`` and ``pred.data``. Column 0 of a row holds the
    label and column 1 holds the text. With the default ``file_base=True``,
    the first row of every file is skipped as a header.
    """

    def _load_split(self, data_dir, file_name, set_type):
        """Read ``data_dir/file_name`` via ``_read_tsv`` and convert its rows to examples."""
        rows = self._read_tsv(os.path.join(data_dir, file_name))
        return self.create_examples(rows, set_type)

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._load_split(data_dir, "train.data", "train")

    def get_dev_examples(self, data_dir):
        """See base class. The dev split is stored in ``val.data`` and tagged "val"."""
        return self._load_split(data_dir, "val.data", "val")

    def get_test_examples(self, data_dir):
        """Return the examples of ``test.data``; every label is set to "0"."""
        return self._load_split(data_dir, "test.data", "test")

    def get_pred_examples(self, data_dir):
        """Return the examples of ``pred.data``; every label is set to "0"."""
        return self._load_split(data_dir, "pred.data", "pred")

    def get_labels(self):
        """See base class."""
        return ["-1", "0", "1"]

    def create_examples(self, lines, set_type, file_base=True):
        """Turn TSV rows into ``InputExample`` objects.

        Args:
            lines: rows as returned by ``_read_tsv``; ``row[0]`` is the label
                and ``row[1]`` is the text. Any further columns are ignored.
            set_type: split name; used as the guid prefix. For "test" and
                "pred" the label column is not read and "0" is used instead.
            file_base: when True, the row at index 0 is skipped as a header.

        Returns:
            A list of ``InputExample`` with ``text_a`` and ``label`` set.
        """
        unlabeled = set_type in ("test", "pred")
        examples = []
        for idx, row in tqdm(enumerate(lines)):
            if file_base and idx == 0:
                continue  # header row
            guid = "%s-%s" % (set_type, idx)
            text = tokenization.convert_to_unicode(row[1])
            label = "0" if unlabeled else tokenization.convert_to_unicode(row[0])
            # Single-input classification: only text_a is needed, text_b is left unset.
            examples.append(InputExample(guid=guid, text_a=text, label=label))
        return examples
def main(_):
    """Script entry point (only its beginning is shown here; the rest is elided as "...").

    The single positional argument is ignored; presumably it is the argv list
    passed in by the TF app runner — confirm against the caller.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Registry mapping a task name to its DataProcessor subclass.
    # NOTE(review): the key is spelled "setiment" (sic). The task name used to
    # select the processor (e.g. in train.sh) presumably has to match this exact
    # spelling, so rename both together if it is ever corrected.
    processors = {
        "setiment": MyProcessor
    }
    ...
导出 SavedModel 的主要代码如下，生成的 pb 文件位于 api 文件夹下（FLAGS.serving_model_save_path 应指向该文件夹）：
def serving_input_receiver_fn():
    """Build the ServingInputReceiver used when exporting the SavedModel.

    Declares four int64 placeholders with a dynamic batch dimension:
    ``input_ids``, ``input_mask`` and ``segment_ids`` with shape
    ``[batch, FLAGS.max_seq_length]``, and ``label_ids`` with shape ``[batch]``.
    The request (receiver) tensors are passed to the model unchanged, so the
    features mapping holds the same tensors under the same keys.

    Returns:
        tf.estimator.export.ServingInputReceiver
    """
    seq_shape = [None, FLAGS.max_seq_length]
    input_ids = tf.placeholder(dtype=tf.int64, shape=seq_shape, name='input_ids')
    input_mask = tf.placeholder(dtype=tf.int64, shape=seq_shape, name='input_mask')
    segment_ids = tf.placeholder(dtype=tf.int64, shape=seq_shape, name='segment_ids')
    # The op name matches its key; it was previously 'unique_ids', which did not
    # match. NOTE(review): label_ids is presumably required only because the
    # model_fn reads features["label_ids"]; prediction clients can likely send
    # dummy values (e.g. 0) — confirm.
    label_ids = tf.placeholder(dtype=tf.int64, shape=[None], name='label_ids')
    receive_tensors = {'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids,
                       'label_ids': label_ids}
    # Features are identical to the raw request tensors; keep a separate dict
    # object, as the original did.
    features = dict(receive_tensors)
    return tf.estimator.export.ServingInputReceiver(features, receive_tensors)
# Export the trained estimator as a SavedModel whose base directory is
# FLAGS.serving_model_save_path; serving_input_receiver_fn defines the request inputs.
estimator.export_savedmodel(
    export_dir_base=FLAGS.serving_model_save_path,
    serving_input_receiver_fn=serving_input_receiver_fn,
)
一键部署（使用 simple_tensorflow_serving）：
simple_tensorflow_serving --model_base_path="./api"
客户端分为两种：一种读取文件，即待预测的文本存放在 TSV 文件中，对应 file_base_client.py；另一种直接输入文本，对应 client.py。实现思路：首先修改 input_fn_builder，使其返回 dataset；然后从 dataset 中取出数据，转换为 list 格式传入模型，得到预测结果。