超级详细手把手讲解BiLSTM+CRF完成命名实体识别(三)

我们要进行第三个步骤的操作,就是使用test功能,

elif args.mode == 'test':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_size))
    #测试
    model.test(test_data)

和demo模式一样,首先使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型,然后也是创建了一个BiLSTM_CRF对象,并调用了它的build_graph()函数,这些咱们之前都讲过,调用这个函数就以为已经进行了模型的设置,只等传入feed_dict就可以得到预测结果了。

然后调用了model.test并传入测试数据。

    def test(self, test):
        saver = tf.train.Saver()
        with tf.Session(config=self.config) as sess:
            self.logger.info('=========== testing ===========')
            saver.restore(sess, self.model_path)
            label_list, seq_len_list = self.dev_one_epoch(sess, test)
            self.evaluate(label_list, seq_len_list, test)

一句saver.restore引入已经训练好的模型,然后直接使用dev_one_epoce函数和evaluate函数,这两个函数之前我们训练的时候也用到过,主要目的是检测训练集的训练效果在测试集表现怎么样,这里的话省略了训练的步骤,而是调用以前训练好的模型,直接进行测试集的检测。

其实dev_one_epoce就是得到了一个有字,真实标签,预测标签组成的矩阵,但是过程是通过predict_one_batch来得到预测标签的,这样就可以用于evaluate的准确率检验了。

这个模式还是比较简单的,和train模式差不多,只是少了训练的步骤。

你可能感兴趣的:(tensorflow,tensorflow)