XLNet PyTorch (simplified version): a code walkthrough


  • Installation
  • Parameters
  • Walkthrough of data_utils.py



First, clone the XLNet-Pytorch source code:

$ git clone https://github.com/graykode/xlnet-Pytorch && cd xlnet-Pytorch

# Install pytorch_pretrained_bert to get the tokenizer (the pretrained BERT tokenizer)
$ pip install pytorch_pretrained_bert


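  • data: path to the training data
  • tokenizer: the tokenizer to use
  • seq_len: sequence length
  • reuse_len: number of tokens reused as the cache (memory)
  • perm_size: maximum permutation length
  • bi_data: whether to build bidirectional batches
  • mask_alpha: how many tokens form a group
  • mask_beta: how many tokens to mask in each group
  • num_predict: number of tokens to predict
  • mem_len: length of the memory cache
  • num_epoch: number of training epochs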
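To make mask_alpha, mask_beta and num_predict more concrete, here is a simplified sketch of n-gram span masking in the spirit of the _sample_mask calls in _create_data. It is our own illustration, not the repository's function: sample_span_mask is a hypothetical name, it ignores token content (no whole-word check), and it always needs an integer goal. Each masked span of n tokens is given a budget of n * mask_alpha // mask_beta unmasked context tokens around it. The values in the usage lines (two halves of 256 tokens, num_predict = 85, mask_alpha = 6, mask_beta = 1) are assumptions chosen for illustration.

```python
import numpy as np


def sample_span_mask(seg_len, mask_alpha, mask_beta, goal_num_predict,
                     max_gram=5, reverse=False, rng=None):
    """Simplified span masking (sketch): returns a bool mask with exactly
    goal_num_predict positions set to True."""
    assert mask_alpha >= mask_beta >= 1 and 0 <= goal_num_predict <= seg_len
    rng = rng if rng is not None else np.random.default_rng()
    mask = np.zeros(seg_len, dtype=bool)
    ngrams = np.arange(1, max_gram + 1)
    pvals = 1.0 / ngrams                 # shorter spans are more likely
    pvals /= pvals.sum()

    num_predict, cur_len = 0, 0
    while cur_len < seg_len and num_predict < goal_num_predict:
        n = min(int(rng.choice(ngrams, p=pvals)), goal_num_predict - num_predict)
        ctx_size = (n * mask_alpha) // mask_beta   # unmasked context budget (>= n)
        l_ctx = int(rng.integers(ctx_size))        # context tokens before the span
        beg, end = cur_len + l_ctx, cur_len + l_ctx + n
        if end > seg_len:
            break
        mask[beg:end] = True
        num_predict += n
        cur_len = end + (ctx_size - l_ctx)         # skip the remaining context
    # top up with random single tokens until the goal is reached exactly
    while num_predict < goal_num_predict:
        i = int(rng.integers(seg_len))
        if not mask[i]:
            mask[i] = True
            num_predict += 1
    # with reverse=True the mask is mirrored, mimicking sampling from the right
    return mask[::-1].copy() if reverse else mask


num_predict = 85                              # illustrative value
num_predict_1 = num_predict // 2              # budget for the second half
num_predict_0 = num_predict - num_predict_1   # budget for the first half
rng = np.random.default_rng(0)
mask_0 = sample_span_mask(256, 6, 1, num_predict_0, rng=rng)
mask_1 = sample_span_mask(256, 6, 1, num_predict_1, rng=rng)
is_masked = np.concatenate([mask_0, mask_1])
print(int(is_masked.sum()))
```

The split of num_predict into num_predict_0 and num_predict_1 follows the code in _create_data, so the concatenated mask has exactly num_predict positions set.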



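import numpy as np

# Excerpt from data_utils.py; SEP_ID, CLS_ID, _split_a_and_b and _sample_mask
# are defined elsewhere in the repository.
def _create_data(sp, input_paths, seq_len, reuse_len,
                bi_data, num_predict, mask_alpha, mask_beta):
    features = []

    # read the whole text file; every line is treated as one sentence
    with open(input_paths, 'r') as f:
        lines = f.readlines()
    input_data, sent_ids, sent_id = [], [], True

    for line in lines:
        tokens = sp.tokenize(line)                   # WordPiece tokens
        cur_sent = sp.convert_tokens_to_ids(tokens)  # token IDs
        input_data.extend(cur_sent)
        sent_ids.extend([sent_id] * len(cur_sent))   # one flag per token
        sent_id = not sent_id                        # flip at every new line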

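To keep things simple, the author handles only a single input file, whereas the original XLNet code processes multiple files. For each file (here there is just one), the goal is to build two lists, input_data and sent_ids. input_data holds the ID of every WordPiece token in the file, and sent_ids marks sentence boundaries: each line of the file is treated as one sentence, and the flag given to its tokens alternates between True and False from one line to the next.
For example, suppose the file contains the two lines "This is the first sentence." and "this is the second sentence and also the end of the paragraph.". sp first splits them into ['this', 'is', 'the', 'first', 'sentence', '.', 'this', 'is', 'the', 'second', 'sentence', 'and', 'also', 'the', 'end', 'of', 'the', 'paragraph', '.'], and converting the tokens to IDs gives [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996, 2203, 1997, 1996, 20423, 1012]. The six tokens of the first sentence (including the period) get the sent_ids [True, True, True, True, True, True], and the 13 tokens of the second sentence get [False, …, False]. The resulting input_data and sent_ids are therefore: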

input_data = [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996, 2203, 1997, 1996, 20423, 1012]
sent_ids = [True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False]


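    # shape of data: [1, number of tokens], e.g. [1, 582] for the author's data
    data = np.array([input_data], dtype=np.int64)
    sent_ids = np.array([sent_ids], dtype=bool)

    # leave room for at least one A/B token plus SEP, SEP and CLS
    assert reuse_len < seq_len - 3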

    data_len = data.shape[1]
    sep_array = np.array([SEP_ID], dtype=np.int64)
    cls_array = np.array([CLS_ID], dtype=np.int64)

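    i = 0
    # slide over the token stream, advancing reuse_len tokens per step
    while i + seq_len <= data_len:
        inp = data[0, i: i + reuse_len]           # the reused part of the window
        tgt = data[0, i + 1: i + reuse_len + 1]   # inp shifted by one token

        results = _split_a_and_b(
            data[0],  # all lines of the text file
            sent_ids[0],
            begin_idx=i + reuse_len,
            tot_len=seq_len - reuse_len - 3,
            extend_target=True)
        if results is None:  # not enough tokens left for segments A and B
            break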

Continuing the example with seq_len = 8 and reuse_len = 4, the first iteration (i = 0) passes these values to _split_a_and_b:

data[0] = [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996, 2203, 1997, 1996, 20423, 1012]
sent_ids[0] = [True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False]
begin_idx = 0 + 4 = 4
tot_len = 8 - 4 - 3 = 1
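The window arithmetic of the while loop can be sketched on its own. iterate_windows is a hypothetical helper that only reproduces the slicing of inp and tgt and the begin_idx/tot_len arguments; it does not call _split_a_and_b, and it uses a plain list instead of the 2-D NumPy array.

```python
def iterate_windows(data, seq_len, reuse_len):
    """Yield (i, inp, tgt, begin_idx, tot_len) for each loop step (sketch)."""
    assert reuse_len < seq_len - 3           # same check as in _create_data
    data_len = len(data)
    i = 0
    while i + seq_len <= data_len:
        inp = data[i: i + reuse_len]         # reused part of the window
        tgt = data[i + 1: i + reuse_len + 1] # inp shifted by one token
        begin_idx = i + reuse_len            # segments A/B start right after inp
        tot_len = seq_len - reuse_len - 3    # room left after SEP, SEP, CLS
        yield i, inp, tgt, begin_idx, tot_len
        i += reuse_len                       # advance by reuse_len tokens


data = [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117,
        6251, 1998, 2036, 1996, 2203, 1997, 1996, 20423, 1012]
windows = list(iterate_windows(data, seq_len=8, reuse_len=4))
for w in windows:
    print(w)
```

With the 19-token example, seq_len = 8 and reuse_len = 4 give three windows starting at i = 0, 4 and 8; at i = 12 the condition i + seq_len <= data_len fails (20 > 19).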

        # unpack the results
        (a_data, b_data, label, _, a_target, b_target) = tuple(results)

        # sample ngram spans to predict
        reverse = bi_data
        if num_predict is None:
            num_predict_0 = num_predict_1 = None
        else:
            # split the prediction budget between the two halves
            num_predict_1 = num_predict // 2
            num_predict_0 = num_predict - num_predict_1

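        # mask_0 covers the reused part; mask_1 covers A + SEP + B + SEP + CLS
        mask_0 = _sample_mask(sp, inp, mask_alpha, mask_beta, reverse=reverse,
                              goal_num_predict=num_predict_0)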
        mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
                                                  sep_array, cls_array]),
                              mask_alpha, mask_beta,
                              reverse=reverse, goal_num_predict=num_predict_1)

        # concatenate data
        cat_data = np.concatenate([inp, a_data, sep_array, b_data,
                                   sep_array, cls_array])
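        # segment ids: 0 = reused part + A + first SEP, 1 = B + second SEP, 2 = CLS
        seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
                  [1] * b_data.shape[0] + [1] + [2])
        assert cat_data.shape[0] == seq_len
        # both halves must have length seq_len // 2, i.e. reuse_len == seq_len // 2
        assert mask_0.shape[0] == seq_len // 2
        assert mask_1.shape[0] == seq_len // 2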

        # the last two CLS's are not used, just for padding purposes
        tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
        assert tgt.shape[0] == seq_len

        is_masked = np.concatenate([mask_0, mask_1], 0)
        if num_predict is not None:
            assert np.sum(is_masked) == num_predict

        feature = {
            "input": cat_data,
            "is_masked": is_masked,
            "target": tgt,
            "seg_id": seg_id,
            "label": [label],
        }
        features.append(feature)

        i += reuse_len
    return features
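The concatenation at the end of the loop can be sketched as a standalone function. assemble_feature is hypothetical, SEP_ID and CLS_ID are placeholder values rather than the repository's constants, and the toy A/B pieces are picked by hand. Comparing the lengths shows that the asserts only hold when len(a_target) + len(b_target) == len(a_data) + len(b_data) + 1, and the toy inputs are chosen to satisfy this.

```python
import numpy as np

# Placeholder special-token IDs (assumption, not the repository's values).
SEP_ID, CLS_ID = 102, 101


def assemble_feature(inp, tgt, a_data, b_data, a_target, b_target, seq_len):
    """Build cat_data, seg_id and the padded targets for one sample (sketch)."""
    sep = np.array([SEP_ID], dtype=np.int64)
    cls = np.array([CLS_ID], dtype=np.int64)
    # layout: [reused tokens | A | SEP | B | SEP | CLS]
    cat_data = np.concatenate([inp, a_data, sep, b_data, sep, cls])
    # 0 for reused tokens, A and the first SEP; 1 for B and the second SEP; 2 for CLS
    seg_id = ([0] * (len(inp) + len(a_data)) + [0] +
              [1] * len(b_data) + [1] + [2])
    # the two trailing CLS entries only pad the targets to seq_len
    full_tgt = np.concatenate([tgt, a_target, b_target, cls, cls])
    assert cat_data.shape[0] == len(seg_id) == full_tgt.shape[0] == seq_len
    return cat_data, seg_id, full_tgt


data = [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117]
cat_data, seg_id, full_tgt = assemble_feature(
    inp=data[0:6], tgt=data[1:7],
    a_data=data[6:8], b_data=data[8:9],
    a_target=data[7:9], b_target=data[8:10],
    seq_len=12)
print(cat_data.tolist())
print(seg_id)
```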
