“基于常识知识的推理问题”源代码分析-最后总结

2021SC@SDUSC

根据前面数周的描述,我们已经对DrFact这个模型有了相当程度的了解。我们不仅通过对其源代码的解析,认识到了这个模型的算法究竟如何,同时也在此过程中了解了许多有关于机器学习、深度学习以及NLP相关的知识。在这次源代码分析中,我将对最后一个源文件进行分析,在这个过程中,我们将会对于DrFact模型完整的流程有一个更加详尽的认知。

一、run_drfact.py源文件代码分析

这次源代码分析的主体,在于run_drfact.py这个源文件,在这个源文件中,定义了许多类以及方法,其源代码行数也是所有源文件中最长的。可想而知,这个源文件在整个模型训练过程中的重要意义。那么接下来就是分析时间。

1.1 调用模块

毫无疑问,作为课题项目的最后收尾,本次调用的模块可谓最多。不仅有我们常用的基础模块,还有老熟人absl,numpy,tf等。不仅如此,这次还用到了albert和bert编码模块,这次的编码模块将不会再借鉴DrKit而是直接调用完整版来进行使用。

不过,话虽如此,我们依旧会在本次调用到我们已经分析过的DrKit模块内容search_utils以及部分我们在DrFact模型中之前介绍的模块。

import collections
import functools
import json
import os
import re
import time

from absl import flags
from albert import tokenization as albert_tokenization
from bert import modeling
from bert import optimization
from bert import tokenization as bert_tokenization
from language.labs.drfact import evaluate
from language.labs.drfact import input_fns
from language.labs.drfact import model_fns
from language.labs.drkit import search_utils
import numpy as np
import random
import tensorflow.compat.v1 as tf
# from tfdeterminism import patch
# patch()

from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import memory_stats as contrib_memory_stats

1.2 flags参数

接下来是对于flags参数的描述,这次定义的参数数量也是前所未有的多,具体我就不一一描述了,不然多少有冗余之嫌。不过我相信,我贴出来的代码已经足以看的很清晰了。

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")
flags.DEFINE_string("tokenizer_type", "bert_tokenization",
                    "The tokenizier type that the BERT model was trained on.")
flags.DEFINE_string("tokenizer_model_file", None,
                    "The tokenizier model that the BERT was trained with.")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "output_prediction_file", "test_predictions.json",
    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string("train_file", None, "JSON for training.")

flags.DEFINE_string("predict_file", None, "JSON for predictions.")
flags.DEFINE_string("predict_prefix", "dev", "JSON for predictions.")

flags.DEFINE_string("test_file", None, "JSON for predictions.")

flags.DEFINE_string("data_type", "onehop",
                    "Whether queries are `onehop` or `twohop`.")

flags.DEFINE_string("model_type", "drfact",
                    "Whether to use `drfact` or `drkit` model.")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_string("train_data_dir", None,
                    "Location of entity/mention/fact files for training data.")

flags.DEFINE_string("f2f_index_dir", None,
                    "Location of fact2fact files for training data.")

flags.DEFINE_string("test_data_dir", None,
                    "Location of entity/mention/fact files for test data.")

flags.DEFINE_string("model_ckpt_toload", "best_model",
                    "Name of the checkpoints.")

flags.DEFINE_string("test_model_ckpt", "best_model", "Name of the checkpoints.")

flags.DEFINE_string("embed_index_prefix", "bert_large", "Prefix of indexes.")

flags.DEFINE_integer("num_hops", 2, "Number of hops in rule template.")

flags.DEFINE_integer("max_entity_len", 4,
                     "Maximum number of tokens in an entity name.")

flags.DEFINE_integer(
    "num_mips_neighbors", 100,
    "Number of nearest neighbor mentions to retrieve for queries in each hop.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "projection_dim", None, "Number of dimensions to project embeddings to. "
    "Set to None to use full dimensions.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool("do_test", False, "Whether to run eval on the test set.")

flags.DEFINE_float(
    "subject_mention_probability", 0.0,
    "Fraction of training instances for which we use subject "
    "mentions in the text as opposed to canonical names.")

flags.DEFINE_integer("train_batch_size", 16, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 32,
                     "Total batch size for predictions.")

flags.DEFINE_float("learning_rate", 3e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 100,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 300,
                     "How many steps to make in each estimator call.")

flags.DEFINE_string("supervision", "fact",
                    "Type of supervision -- `fact` or `entity` or `fact+entity`.")

flags.DEFINE_float("entity_score_threshold", 1e-2,
                   "Minimum score of an entity to retrieve sparse neighbors.")
flags.DEFINE_float("fact_score_threshold", 1e-2,
                   "Minimum score of a fact to retrieve sparse neighbors.")

flags.DEFINE_float("self_follow_threshold", 5e-5,
                   "Minimum score of a fact to retrieve sparse neighbors.")
                  
flags.DEFINE_float("softmax_temperature", 2.,
                   "Temperature before computing softmax.")

flags.DEFINE_string(
    "sparse_reduce_fn", "max",
    "Function to aggregate sparse search results for a set of "
    "entities.")

flags.DEFINE_string("sparse_strategy", "dense_first",
                    "How to combine sparse and dense components.")

flags.DEFINE_boolean("intermediate_loss", False,
                     "Compute loss on intermediate layers.")

flags.DEFINE_boolean("light", False, "If true run in light mode.")
flags.DEFINE_boolean("is_excluding", False,
                     "If true exclude question and wrong choices' concepts.")

# flags.DEFINE_string(
#     "qry_layers_to_use", "-1",
#     "Comma-separated list of layer representations to use as the fixed "
#     "query representation.")

flags.DEFINE_string(
    "qry_aggregation_fn", "concat",
    "Aggregation method for combining the outputs of layers specified using "
    "`qry_layers`.")

flags.DEFINE_string(
    "entity_score_aggregation_fn", "max",
    "Aggregation method for combining the mention logits to entities.")

flags.DEFINE_float("question_dropout", 0.2,
                   "Dropout probability for question BiLSTMs.")

flags.DEFINE_integer("question_num_layers", 2,
                     "Number of layers for question BiLSTMs.")

flags.DEFINE_integer("num_preds", 100, "Use -1 for all predictions.")

flags.DEFINE_boolean(
    "ensure_answer_sparse", False,
    "If true, ensures answer is among sparse retrieval results"
    "during training.")

flags.DEFINE_boolean(
    "ensure_answer_dense", False,
    "If true, ensures answer is among dense retrieval results "
    "during training.")

flags.DEFINE_boolean(
    "train_with_sparse", True,
    "If true, multiplies logits with sparse retrieval results "
    "during training.")

flags.DEFINE_boolean(
    "predict_with_sparse", True,
    "If true, multiplies logits with sparse retrieval results "
    "during inference.")

flags.DEFINE_boolean("fix_sparse_to_one", True,
                     "If true, sparse search matrix is fixed to {0,1}.")

flags.DEFINE_boolean("l2_normalize_db", False,
                     "If true, pre-trained embeddings are normalized to 1.")

flags.DEFINE_boolean("load_only_bert", False,
                     "To load only BERT variables from init_checkpoint.")

flags.DEFINE_boolean(
    "use_best_ckpt_for_predict", False,
    "If True, loads the best_model checkpoint in model_dir, "
    "instead of the latest one.")

flags.DEFINE_bool("profile_model", False, "Whether to run profiling.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_integer("random_seed", 1, "Random seed for reproducibility.")

flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")

flags.DEFINE_bool(
    "verbose_logging", False,
    "If true, all of the warnings related to data processing will be printed. "
    "A number of warnings are expected for a normal SQuAD evaluation.")

flags.DEFINE_bool("debug", False,
                  "If true, only print the flags but not run anything.")

1.3 三个重要类

在本次的文件中,一共定义了三个极为重要的类,分别是QAConfig类,MIPSConfig类和FactMIPSConfig类。通过名字就可以看出来,QAConfig中存放的是QA系统的配置信息,MIPSConfig中存放的是最大内积搜索的配置信息,而FactMIPSConfig中存放的是事实矩阵最大内积搜索的配置信息。接下来,我将分别对它们进行一定程度的分析。

首先是QAConfig类,从注释可知,这里存放的是QA模型的超参数。

class QAConfig(object):
  """Hyperparameters for the QA model."""

  def __init__(self, qry_aggregation_fn, dropout,
               qry_num_layers, projection_dim, num_entities, max_entity_len,
               ensure_answer_sparse, ensure_answer_dense, train_with_sparse,
               predict_with_sparse, fix_sparse_to_one, supervision,
               l2_normalize_db, entity_score_aggregation_fn,
               entity_score_threshold, fact_score_threshold, self_follow_threshold,
               softmax_temperature, sparse_reduce_fn, intermediate_loss,
               train_batch_size, predict_batch_size, light, sparse_strategy,
               load_only_bert):
    # self.qry_layers_to_use = [int(vv) for vv in qry_layers_to_use.split(",")]
    self.qry_aggregation_fn = qry_aggregation_fn
    self.dropout = dropout
    self.qry_num_layers = qry_num_layers
    self.projection_dim = projection_dim
    self.num_entities = num_entities
    self.max_entity_len = max_entity_len
    self.load_only_bert = load_only_bert
    self.ensure_answer_sparse = ensure_answer_sparse
    self.ensure_answer_dense = ensure_answer_dense
    self.train_with_sparse = train_with_sparse
    self.predict_with_sparse = predict_with_sparse
    self.fix_sparse_to_one = fix_sparse_to_one
    self.supervision = supervision
    self.l2_normalize_db = l2_normalize_db
    self.entity_score_aggregation_fn = entity_score_aggregation_fn
    self.entity_score_threshold = entity_score_threshold
    self.fact_score_threshold = fact_score_threshold
    self.self_follow_threshold = self_follow_threshold
    self.softmax_temperature = softmax_temperature
    self.sparse_reduce_fn = sparse_reduce_fn
    self.intermediate_loss = intermediate_loss
    self.train_batch_size = train_batch_size
    self.predict_batch_size = predict_batch_size
    self.light = light
    self.sparse_strategy = sparse_strategy

其次是QAConfig类,从注释可知,这里存放的是对提到的信息索引进行MIPS的模型的超参数。

class MIPSConfig(object):
  """Hyperparameters for the MIPS model of mention index."""

  def __init__(self, ckpt_path, ckpt_var_name, num_mentions, emb_size,
               num_neighbors):
    self.ckpt_path = ckpt_path
    self.ckpt_var_name = ckpt_var_name
    self.num_mentions = num_mentions
    self.emb_size = emb_size
    self.num_neighbors = num_neighbors

再次是QAConfig类,从注释可知,这里存放的是事实索引进行MIPS的模型的超参数。

class FactMIPSConfig(object):
  """Hyperparameters for the MIPS model of fact index."""

  def __init__(self, ckpt_path, ckpt_var_name, num_facts, emb_size,
               num_neighbors):
    self.ckpt_path = ckpt_path
    self.ckpt_var_name = ckpt_var_name
    self.num_facts = num_facts
    self.emb_size = emb_size
    self.num_neighbors = num_neighbors

1.4 主函数main

由于中间夹有过于大量的函数,因此我们无需对他们的细节有过多关注,否则会显得十分冗长,让人忍不住发困。而直接从主函数main入手,我们同样可以对这些函数有一个明晰的认知,同时也可以认识到主函数的结构,因此我选择了直接从主函数入手。接下来,就是对主函数的分析过程。

在主函数的定义中,首先先进行一些有关随机种子的预设,诸如使用tensorflow的set_random_seed()函数等等,从而得以保证后续的随机过程不会出差错。

def main(_):
  """Main function."""
  tf.logging.set_verbosity(tf.logging.INFO)

  # Control the random seed.
  tf.set_random_seed(FLAGS.random_seed)
  tf.random.set_random_seed(FLAGS.random_seed)
  os.environ['PYTHONHASHSEED']=str(FLAGS.random_seed)
  random.seed(FLAGS.random_seed)
  np.random.seed(FLAGS.random_seed)

  if FLAGS.debug:
    print(FLAGS)
    return

接下来,对数据类型和模型类型以及BERT进行决定。 

  1. 如果数据类型是opencsr,则将数据集类设为OpenCSRDataset,然后将eval_fn设为opencsr_eval_fn。
  2. 如果模型类型是drkit,则使用model_fns中的create_drkit_model来构造drkit模型;而如果是drfact,则使用model_fns中的create_drfact_model来构造drfact模型。这里用意是用来准备进行比对性能使用的。
  3. 最后,加载BERT模型,准备用于编码。
  # Decide data type.
  if FLAGS.data_type == "opencsr":
    dataset_class = input_fns.OpenCSRDataset
    eval_fn = evaluate.opencsr_eval_fn # only eval the recall (R1@300) now

  # Decide model type.
  if FLAGS.model_type == "drkit":
    create_model_fn = functools.partial(
        model_fns.create_drkit_model, num_hops=FLAGS.num_hops)
  elif FLAGS.model_type == "drfact":
    create_model_fn = functools.partial(
        model_fns.create_drfact_model, num_hops=FLAGS.num_hops)
  else:
    tf.logging.info("Wrong model_type...")
  # Load BERT.
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

下一步,加载提到的信息文件和实体文件,二者文件路径分别是使用参数train_data_dir和"entities.json"或"subparas.json"拼接之后得到的。值得一提的是,如果模型是drkit,还应该加载mention2text文件和all_mentions文件。

  # Load mention and entity files.
  tf.logging.info("Loading metadata about entities and facts...")
  mention2text = None
  all_mentions = None
  if FLAGS.model_type == "drkit":
    mention2text = json.load(
        tf.gfile.Open(os.path.join(FLAGS.train_data_dir, "mention2text.json")))
    all_mentions = np.load(
        tf.gfile.Open(os.path.join(FLAGS.train_data_dir, "mentions.npy"), "rb"))
  entity2id, entity2name = json.load(
      tf.gfile.Open(os.path.join(FLAGS.train_data_dir, "entities.json")))
  entityid2name = {str(i): entity2name[e] for e, i in entity2id.items()}
  all_paragraphs = json.load(
      tf.gfile.Open(os.path.join(FLAGS.train_data_dir, "subparas.json")))

 根据已经获得的信息,构造QAConfig类的对象,即封装QA模型的超参数。

  qa_config = QAConfig(
      qry_aggregation_fn=FLAGS.qry_aggregation_fn,
      dropout=FLAGS.question_dropout,
      qry_num_layers=FLAGS.question_num_layers,
      projection_dim=FLAGS.projection_dim,
      load_only_bert=FLAGS.load_only_bert,
      num_entities=len(entity2id),
      max_entity_len=FLAGS.max_entity_len,
      ensure_answer_sparse=FLAGS.ensure_answer_sparse,
      ensure_answer_dense=FLAGS.ensure_answer_dense,
      train_with_sparse=FLAGS.train_with_sparse,
      predict_with_sparse=FLAGS.predict_with_sparse,
      fix_sparse_to_one=FLAGS.fix_sparse_to_one,
      supervision=FLAGS.supervision,
      l2_normalize_db=FLAGS.l2_normalize_db,
      entity_score_aggregation_fn=FLAGS.entity_score_aggregation_fn,
      entity_score_threshold=FLAGS.entity_score_threshold,
      fact_score_threshold=FLAGS.fact_score_threshold,
      self_follow_threshold=FLAGS.self_follow_threshold,
      softmax_temperature=FLAGS.softmax_temperature,
      sparse_reduce_fn=FLAGS.sparse_reduce_fn,
      intermediate_loss=FLAGS.intermediate_loss,
      light=FLAGS.light,
      sparse_strategy=FLAGS.sparse_strategy,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

接下来,根据模型类型进行判断。如果模型类型是drkit,则将fact_mips_config设为None,表明构建MIPSConfig类对象封装的是为drkit模型准备的对提到的信息索引进行MIPS的模型的超参数;如果模型是drfact,则将mips_config设为None,表明构建FactMIPSConfig类对象封装的是为drkit模型准备的对事实索引进行MIPS的模型的超参数。验证输入标志或抛出异常。

  if FLAGS.model_type == "drkit":
    fact_mips_config = None
    mips_config = MIPSConfig(
        ckpt_path=os.path.join(FLAGS.train_data_dir,
                              "%s_mention_feats" % FLAGS.embed_index_prefix),
        ckpt_var_name="db_emb",
        num_mentions=len(mention2text),
        emb_size=FLAGS.projection_dim * 2,
        num_neighbors=FLAGS.num_mips_neighbors)
  elif FLAGS.model_type == "drfact":
    mips_config = None
    fact_mips_config = FactMIPSConfig(
        ckpt_path=os.path.join(FLAGS.train_data_dir,
                              "%s_fact_feats" % FLAGS.embed_index_prefix),
        ckpt_var_name="fact_db_emb",
        num_facts=len(all_paragraphs),
        emb_size=FLAGS.projection_dim * 2,
        num_neighbors=FLAGS.num_mips_neighbors)
  validate_flags_or_throw()

 保存训练的参数。

  tf.gfile.MakeDirs(FLAGS.output_dir)

  # Save training flags.
  if FLAGS.do_train:
    json.dump(tf.app.flags.FLAGS.flag_values_dict(),
              tf.gfile.Open(os.path.join(FLAGS.output_dir, "flags.json"), "w"))

然后根据参数tokenizer_type的值进行判断。如果值为bert_tokenization,那么构建bert的分词器;如果值为albert_tokenization,那么构建albert的分词器。接下来判断是否用到tpu,用到的话则构建tpu簇解释器。

  # tokenizer = tokenization.FullTokenizer(
  #     vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  if FLAGS.tokenizer_type == "bert_tokenization":
    tokenizer = bert_tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=True)
  elif FLAGS.tokenizer_type == "albert_tokenization":
    tokenizer = albert_tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file,
        do_lower_case=False,
        spm_model_file=FLAGS.tokenizer_model_file)
  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2

使用tensorflow.estimator.tpui.Runconfig,将准备好的配置信息配置到run_config对象中。

  # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25)
  
  session_config = tf.ConfigProto()
  session_config.gpu_options.allow_growth = True
  # session_config.gpu_options.per_process_gpu_memory_fraction = 0.25
  # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
  session_config.log_device_placement = False

  run_config = tf.estimator.tpu.RunConfig(
      tf_random_seed=FLAGS.random_seed, # important
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=3,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host),
      session_config=session_config)

 接下来则是重复雷同fact2fact.py文件那篇博客中的参数判断是否执行的方案,就不多做赘述了。其实并不算特别难理解,只要回过头去翻看一下之前fact2fact.py的那篇博客就能对这个步骤有一个比较明晰的理解了。

  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.num_preds < 0:
    FLAGS.num_preds = len(entity2id)
  if FLAGS.do_train:
    train_dataset = dataset_class(
        in_file=FLAGS.train_file,
        tokenizer=tokenizer,
        subject_mention_probability=FLAGS.subject_mention_probability,
        max_qry_length=FLAGS.max_query_length,
        is_training=True,
        entity2id=entity2id,
        tfrecord_filename=os.path.join(FLAGS.output_dir, "train.tf_record"))
    num_train_steps = int(train_dataset.num_examples / FLAGS.train_batch_size *
                          FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
  if FLAGS.do_predict:
    eval_dataset = dataset_class(
        in_file=FLAGS.predict_file,
        tokenizer=tokenizer,
        subject_mention_probability=0.0,
        max_qry_length=FLAGS.max_query_length,
        is_training=False,
        entity2id=entity2id,
        tfrecord_filename=os.path.join(
            FLAGS.output_dir, "eval.%s.tf_record" % FLAGS.predict_prefix))
    qa_config.predict_batch_size = FLAGS.predict_batch_size
  summary_obj = None
  # summary_obj = summary.TPUSummary(FLAGS.output_dir,
  #                                  FLAGS.save_checkpoints_steps)
  model_fn = model_fn_builder(
      bert_config=bert_config,
      qa_config=qa_config,
      mips_config=mips_config,
      fact_mips_config=fact_mips_config,
      init_checkpoint=FLAGS.init_checkpoint,
      e2m_checkpoint=os.path.join(FLAGS.train_data_dir, "ent2ment.npz"),
      m2e_checkpoint=os.path.join(FLAGS.train_data_dir, "coref.npz"),
      e2f_checkpoint=os.path.join(FLAGS.train_data_dir, "ent2fact_500.npz"),
      # Note: use a hp.
      f2e_checkpoint=os.path.join(FLAGS.train_data_dir, "fact_coref.npz"),
      f2f_checkpoint=os.path.join(FLAGS.f2f_index_dir, "fact2fact.npz"),
      entity_id_checkpoint=os.path.join(FLAGS.train_data_dir, "entity_ids"),
      entity_mask_checkpoint=os.path.join(FLAGS.train_data_dir, "entity_mask"),
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      create_model_fn=create_model_fn,
      summary_obj=summary_obj)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  if FLAGS.do_train or FLAGS.do_predict:
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", train_dataset.num_examples)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    train(train_dataset, estimator, num_train_steps)

  if FLAGS.do_predict:
    continuous_eval(
        eval_dataset,
        estimator,
        mention2text,
        entityid2name,
        qa_config.supervision,
        eval_fn,
        paragraphs=all_paragraphs,
        mentions=all_mentions)

综上,这就是对run_drfact.py文件的所有源码的解析了。

二、总结

当我写到这里的时候,这个学期的“软件工程应用与实践”课程已然走到了尾声。回忆过往,不禁让我十分感慨。从最初的时候只是为了赶上末班车而急哄哄地跑到老师办公室却发现自己还没组队,到和临时搭伙的队友一起商议选题细节,再到因为之前的项目过难而被迫转换课题,最后再到如今一个学期下来,对目前手头的课题工作有了很大程度上认知的提升。在这个过程中,我从最初对NLP的毫无头绪,到现在至少能对QA系统问题有一个比较大致的了解,我已经感到十分满足。

对于这个课题,其实我是比较歉疚而惋惜的,因为我自己这学期身体状况并不是很好,我经常是有了很好的想法却没有办法施行,有一些感悟却常常没法及时地记下来,乃至错失了很好的理解和点子。我对理论知识的不断精进,却无法直接应用出来,在苦恼之余,只能够对于尚未完成的工作继续深挖耕耘,这既是幸事,也是不幸。实际上,我也不止一次尝试调配环境,做出许多能够适配工作的努力,但是让人沮丧的是,我的环境似乎一直调配的不是很理想,这也导致了我对于当前课题项目的代码理解需要花更大的成本。

不过无论如何,我还是非常感谢我的老师以及实验室的学长,他们在我们小组遇到很大的瓶颈的时候主动伸出援手,让我们小组可以换一个课题。这让我们在课题项目不至于难产的同时,也开拓了我们对于NLP的理解,得以从另一个角度看待NLP相关的问题。从文本抽取到QA系统,我们可以在对NLP问题有认识开始时,便能够从一个很大的角度看待NLP,认识到其中许许多多的分支领域,这不由得让我对于NLP能够始终保持一个很高程度的兴趣。

不仅如此,在本次对于开放式常识知识的推理问题的研究过程中,我同样也有了许多更加专业化的认知。词向量、Transformer、BERT、MIPS这些名词的理解也好,对于不断迭代的文本信息存储方式也罢,亦或是对于超图这样的模型的认知等等,这些专业的知识都极大地增长了我的见闻,让我得以在有关NLP的研究中不再如同盲人摸象一般胡乱臆测,而是有所根据地进行思考。哪怕仅仅只是所谓的管中窥豹,那也是难得的财富了。

此外,在这次课题项目的研究过程中,我同样得以反哺到这学期我选择的机器学习这一门课上。在对于NLP的研究过程中,我对文本类属性的认知得以超越当前课程学习中得到的理解,而对未来的发展有相当程度的影响。如果我未来有机会的话,我一定会尝试把如今学习到的NLP知识应用到未来的学习生活和工作之中。

这学期这门课程给我带来了精神上的极大富足和充实,让我的眼光更加开阔了,我真心实意地对这一趟旅程感到快乐。这就是我对这一个学期这门课程的一个总结。

你可能感兴趣的:(深度学习,自然语言处理,人工智能)