Modifying the prediction module of BERT's run_classifier for text classification

Modification point 1: in the model_fn() function of run_classifier.py

Original code 1:

else:
	output_spec = tf.contrib.tpu.TPUEstimatorSpec(
		mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn)

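Replacement code 1 (in PREDICT mode the output becomes a dict holding the predicted class id, the probabilities and the logits):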

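elif mode == tf.estimator.ModeKeys.PREDICT:
    def prediction_fn(logits, probabilities):
        # Predicted class id = index of the largest logit
        predicted_classes = tf.argmax(logits, axis=1, output_type=tf.int32)
        return {
            'pred_class_ids': predicted_classes[:, tf.newaxis],
            'probabilities': probabilities,
            'logits': logits}

    predictions = prediction_fn(logits, probabilities)
    # TPUEstimatorSpec (with scaffold_fn) keeps this working when use_tpu=True;
    # a plain EstimatorSpec is only accepted by TPUEstimator on CPU/GPU
    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)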
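Because the predictions are now a dict, each item returned by estimator.predict() is also a dict, with every value sliced along the batch axis. The sketch below imitates that in plain Python (no TensorFlow) to show why prediction["pred_class_ids"] ends up as a one-element array per example. The helper names build_predictions and unbatch_predictions are hypothetical, and this only approximates what the Estimator does internally.

```python
# Illustrative sketch (pure Python, no TensorFlow): a batched predictions dict
# split into one dict per example, similar to what Estimator.predict() yields.
# build_predictions / unbatch_predictions are hypothetical helper names.


def argmax(row):
    """Index of the largest value in a list (first one on ties)."""
    return max(range(len(row)), key=lambda j: row[j])


def build_predictions(logits, probabilities):
    """Mirror the dict built in model_fn: class ids get an extra axis."""
    pred_class_ids = [[argmax(row)] for row in logits]  # shape [batch, 1]
    return {
        "pred_class_ids": pred_class_ids,
        "probabilities": probabilities,
        "logits": logits,
    }


def unbatch_predictions(batch):
    """Yield one dict per example by indexing every value along axis 0."""
    batch_size = len(next(iter(batch.values())))
    for i in range(batch_size):
        yield {key: value[i] for key, value in batch.items()}


if __name__ == "__main__":
    logits = [[0.2, 1.5], [2.0, -1.0]]
    probabilities = [[0.21, 0.79], [0.95, 0.05]]
    for prediction in unbatch_predictions(build_predictions(logits, probabilities)):
        # prints "[1] [0.21, 0.79]" then "[0] [0.95, 0.05]"
        print(prediction["pred_class_ids"], prediction["probabilities"])
```

Since each example's class id is a one-element list (an array of shape [1] in TensorFlow), it has to be indexed or converted with int() before being compared with a label.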

Modification point 2: in the main() function of run_classifier.py

Original code 2:

with tf.gfile.GFile(output_predict_file, "w") as writer:
    tf.logging.info("***** Predict results *****")        
    for prediction in result:
        output_line = "\t".join(
        	str(class_probability) for class_probability in prediction) + "\n" 
        writer.write(output_line)   
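With the original code, each line of the output file holds only the tab-separated class probabilities of one example, so the predicted class has to be recovered afterwards. A minimal sketch of that post-processing, assuming the file format just described; parse_probability_lines is a hypothetical helper and io.StringIO stands in for the real output file:

```python
# Sketch: recover predicted class ids from a file where each line is the
# tab-separated class probabilities of one example (the original output).
# parse_probability_lines is a hypothetical helper name.
import io


def parse_probability_lines(lines):
    """Return (class_id, probabilities) for every non-empty line."""
    results = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        probs = [float(x) for x in line.split("\t")]
        class_id = max(range(len(probs)), key=lambda j: probs[j])
        results.append((class_id, probs))
    return results


if __name__ == "__main__":
    # Stand-in for the real output file of a two-class task
    fake_file = io.StringIO("0.1\t0.9\n0.8\t0.2\n")
    for class_id, probs in parse_probability_lines(fake_file):
        print(class_id, probs)  # "1 [0.1, 0.9]" then "0 [0.8, 0.2]"
```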

Replacement code 2 (writes the predictions and the accuracy on the test set):

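with tf.gfile.GFile(output_predict_file, "w") as writer:
    tf.logging.info("***** Predict results *****")
    # Same label -> id mapping that convert_single_example() uses.
    # Assumes get_test_examples() reads the real test labels; the stock
    # processors assign a placeholder label (e.g. "0") to test examples.
    label_map = {label: i for (i, label) in enumerate(label_list)}
    pred_true_nums = 0    # number of correct predictions
    test_sample_nums = 0  # number of test samples
    for (i, prediction) in enumerate(result):
        pred_class_id = int(prediction["pred_class_ids"][0])
        writer.write("\t".join([str(pred_class_id)] + [
            str(p) for p in prediction["probabilities"]]) + "\n")
        # Correct only if the prediction matches the i-th example's true label
        if pred_class_id == label_map[predict_examples[i].label]:
            pred_true_nums += 1
        test_sample_nums += 1
    if test_sample_nums > 0:
        # float() avoids integer division under Python 2
        writer.write("\npred_accuracy:%f\n" % (float(pred_true_nums) / test_sample_nums))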
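Accuracy is the fraction of test examples whose predicted class id equals the id of the true label; summing the predicted class ids instead counts every class-1 prediction as correct. A pure-Python sketch of the computation, assuming the true labels are available as strings in the same order as the predictions; compute_accuracy is a hypothetical helper and label_list plays the role of processor.get_labels():

```python
# Sketch: accuracy from predicted class ids and true label strings.
# compute_accuracy is a hypothetical helper; label_list plays the role of
# processor.get_labels() in run_classifier.py.


def compute_accuracy(pred_class_ids, true_labels, label_list):
    """Fraction of examples whose predicted id matches the true label id."""
    label_map = {label: i for (i, label) in enumerate(label_list)}
    if not pred_class_ids:
        return 0.0
    correct = sum(
        1 for pred, label in zip(pred_class_ids, true_labels)
        if pred == label_map[label])
    # float() keeps the division exact under Python 2 as well
    return float(correct) / len(pred_class_ids)


if __name__ == "__main__":
    label_list = ["0", "1"]
    preds = [1, 1, 1, 1]
    labels = ["0", "0", "1", "1"]
    print(compute_accuracy(preds, labels, label_list))  # 0.5
    # Summing the predicted ids reports 1.0 for the same data
    print(float(sum(preds)) / len(preds))  # 1.0
```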
