BERT: implementing end-to-end continued pretraining

    Many ACL 2020 papers show that continuing pretraining on in-domain data noticeably improves the final model, typically by a few percentage points. The standard continued-pretraining pipeline is create_pretraining_data (build the pretraining data) followed by run_pretraining (continue pretraining on that data). In practice, however, we often want to avoid the heavy disk IO this involves (building the pretraining data generates a large amount of IO, which wears the disk), so here we merge data creation and pretraining into a single script, making the whole process end to end.

1. First, note that continued pretraining reads its data through the input_fn_builder function, so rewriting input_fn_builder is what makes end-to-end training possible.

# This is the line in input_fn_builder that reads training data from files;
# it is the one we want to replace:
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))

Replace it with:

# Switch to reading from a generator, i.e. pull the data straight out of
# the create_pretraining_data logic:
d = tf.data.Dataset.from_generator(take_write_2_flow, tf.string, tf.TensorShape([]))
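
For context, here is a minimal sketch of the whole modified input_fn_builder, assuming TF 1.x (the feature spec and the int64-to-int32 cast follow _decode_record in BERT's run_pretraining.py; the rest simplifies the original TPU-oriented pipeline and is illustrative). One subtlety: the generator yields serialized tf.train.Example strings rather than file names, so the TFRecordDataset/parallel-interleave step of the original code has to be dropped and the strings parsed directly:

def input_fn_builder(max_seq_length, max_predictions_per_seq, is_training):
    """Creates an `input_fn` closure fed by the generator instead of files."""

    def input_fn(params):
        batch_size = params["batch_size"]

        name_to_features = {
            "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "masked_lm_positions": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_ids": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_weights": tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
            "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
        }

        def _decode_record(record):
            example = tf.parse_single_example(record, name_to_features)
            # tf.Example only supports int64, but the model expects int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t
            return example

        # The generator yields serialized tf.train.Example strings, so no
        # TFRecordDataset is needed; repeat() re-invokes the generator.
        d = tf.data.Dataset.from_generator(
            take_write_2_flow, tf.string, tf.TensorShape([]))
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)
        d = d.map(_decode_record)
        d = d.batch(batch_size, drop_remainder=True)
        return d

    return input_fn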

2. So what does that generator function look like? The essence is twofold: first, yield turns the function into a generator; second, tf.data.Dataset.from_generator above consumes that generator (a toy demonstration follows, then the real function).

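Before the real function, a self-contained toy demonstration of this yield + from_generator pattern (TF 1.x; all names here are illustrative):

import tensorflow as tf

def toy_gen():
    # Each yield becomes one string element of the dataset.
    for s in [b"first", b"second", b"third"]:
        yield s

ds = tf.data.Dataset.from_generator(toy_gen, tf.string, tf.TensorShape([]))
next_elem = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(3):
        print(sess.run(next_elem))  # b'first', b'second', b'third'
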
# Build training data in batches of 1,000 lines and yield each serialized
# example one at a time, so end-to-end training works without writing the
# whole dataset to disk first.
import collections
import random

import tensorflow as tf

import tokenization  # from the BERT repo

# FLAGS is the flag object already defined in run_pretraining.py.


def take_write_2_flow():
    tf.logging.set_verbosity(tf.logging.INFO)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Reading from input files ***")
    for input_file in input_files:
        tf.logging.info("  %s", input_file)

    rng = random.Random(FLAGS.random_seed)

    lines_to_process = []
    for input_file in input_files:
        with tf.gfile.GFile(input_file, "r") as reader:
            while True:
                line = reader.readline()
                if not line:
                    break
                lines_to_process.append(line)
                if len(lines_to_process) == 1000:
                    # fusion_input_and_out is the author's helper (defined
                    # elsewhere) that turns raw lines into TrainingInstances.
                    instances = fusion_input_and_out(lines_to_process, tokenizer, rng)
                    records = create_instance_to_example_files(
                        instances, tokenizer, FLAGS.max_seq_length,
                        FLAGS.max_predictions_per_seq)
                    for record in records:
                        yield record
                    lines_to_process = []

    # Flush the final partial batch so trailing lines are not silently dropped.
    if lines_to_process:
        instances = fusion_input_and_out(lines_to_process, tokenizer, rng)
        records = create_instance_to_example_files(
            instances, tokenizer, FLAGS.max_seq_length,
            FLAGS.max_predictions_per_seq)
        for record in records:
            yield record

            
def create_instance_to_example_files(instances, tokenizer, max_seq_length,
                                     max_predictions_per_seq):
    """Serialize `TrainingInstance`s into tf.train.Example byte strings."""
    serialized_examples = []

    for (inst_index, instance) in enumerate(instances):
        if inst_index == 0:
            tf.logging.info("*** Example ***")
            tf.logging.info("tokens: %s" % " ".join(
                [tokenization.printable_text(x) for x in instance.tokens]))
        input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = list(instance.segment_ids)
        assert len(input_ids) <= max_seq_length

        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        masked_lm_positions = list(instance.masked_lm_positions)
        masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
        masked_lm_weights = [1.0] * len(masked_lm_ids)

        while len(masked_lm_positions) < max_predictions_per_seq:
            masked_lm_positions.append(0)
            masked_lm_ids.append(0)
            masked_lm_weights.append(0.0)

        next_sentence_label = 1 if instance.is_random_next else 0

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(input_ids)
        features["input_mask"] = create_int_feature(input_mask)
        features["segment_ids"] = create_int_feature(segment_ids)
        features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
        features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
        features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
        features["next_sentence_labels"] = create_int_feature([next_sentence_label])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        serialized_examples.append(tf_example.SerializeToString())

    return serialized_examples
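
For completeness, create_int_feature and create_float_feature are the small helpers from BERT's create_pretraining_data.py, reproduced here so the snippet above is self-contained:

def create_int_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))


def create_float_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))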

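Finally, the wiring in run_pretraining.py's main() stays essentially as in the original; a hedged sketch, the only change being that input_fn_builder (in the sketch version above) no longer takes a list of TFRecord files:

train_input_fn = input_fn_builder(
    max_seq_length=FLAGS.max_seq_length,
    max_predictions_per_seq=FLAGS.max_predictions_per_seq,
    is_training=True)
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)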