Source-code walkthrough of beam_search in transformers generation_utils

This post walks through the beam_search method of the transformers library's generation_utils module, using the source code itself to deepen understanding; comments are added at the important steps.
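Before diving into the source, here is a minimal sketch of how beam_search is usually reached in practice: generate() is called with num_beams > 1 and no sampling (the checkpoint name and input text below are only assumptions for illustration).

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tokenizer("The tower is 324 metres tall.", return_tensors="pt")
# num_beams > 1 without sampling routes generate() into the beam_search branch
outputs = model.generate(
    inputs["input_ids"],
    num_beams=4,
    num_return_sequences=2,
    max_length=32,
    early_stopping=True,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))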

# beam_search main loop

while True:

    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)  # assemble the inputs the decoder needs for the next step

    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )  # run the cur-step inputs through the model; for BART this is the decoder pass: decoder_input first does self-attention, then cross-attention against the encoder output, returning hidden states with the same size as decoder_input

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]  # take the cur-step prediction distribution, i.e. the vocab logits at the last position

    # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
    # cannot be generated both before and after the `F.log_softmax` operation.
    next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)  # model-specific logits adjustment; can be ignored for most models
    # small probabilities shrink further under repeated multiplication; to avoid underflowing the floating point numbers, the computation is turned into taking logs and summing them
    next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

    next_token_scores = logits_processor(input_ids, next_token_scores)  # logits processors such as MinLengthLogitsProcessor adjust the scores of the cur-step output distribution
    next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)  # accumulate the total score from the start up to cur-step: the running score of each partial hypothesis plus the cur-step scores over the vocabulary
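    # Illustrative numbers (not from the source): multiplying raw probabilities such as
    # 0.1 * 0.2 = 0.02 keeps shrinking and eventually underflows, whereas summing
    # log-probabilities log(0.1) + log(0.2) = -2.303 + (-1.609) = -3.912 stays numerically stable;
    # this running sum of log-probs is what beam_scores accumulates.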

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # reshape for beam search
    vocab_size = next_token_scores.shape[-1]
    next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)  # flatten all beams of a sample before taking top-k; flattening lets us pick the globally highest-confidence tokens across beams

    next_token_scores, next_tokens = torch.topk(
        next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
    )

    next_indices = next_tokens // vocab_size  # floor-divide by vocab_size to undo the flattening and recover which beam each candidate came from
    next_tokens = next_tokens % vocab_size  # modulo vocab_size to get each candidate's token id within the vocabulary

    # stateless
    beam_outputs = beam_scorer.process(
        input_ids,
        next_token_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
    )  # handle the cur-step candidates; explained in detail below
    beam_scores = beam_outputs["next_beam_scores"]
    beam_next_tokens = beam_outputs["next_beam_tokens"]
    beam_idx = beam_outputs["next_beam_indices"]

    input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )
    if model_kwargs["past"] is not None:
        model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

    # increase cur_len
    cur_len = cur_len + 1

    if beam_scorer.is_done or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True

sequence_outputs = beam_scorer.finalize(
    input_ids,
    beam_scores,
    next_tokens,
    next_indices,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    max_length=stopping_criteria.max_length,
)  # collect the beam_search hypotheses and rank them by score to obtain the final result; the function is analyzed below

Finally, the decoded result sequence_outputs is returned.
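To make the reshape / top-k bookkeeping above concrete, here is a small self-contained sketch with toy numbers (batch_size, num_beams and vocab_size are made up purely for illustration):

import torch

batch_size, num_beams, vocab_size = 2, 3, 5
# pretend these are the accumulated log-prob scores, shape (batch_size * num_beams, vocab_size)
next_token_scores = torch.randn(batch_size * num_beams, vocab_size)

# flatten all beams of a sample into one row so top-k compares candidates across beams
flat_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
top_scores, top_ids = torch.topk(flat_scores, 2 * num_beams, dim=1)

beam_indices = top_ids // vocab_size  # which beam (0..num_beams-1) each candidate extends
token_ids = top_ids % vocab_size      # which vocab entry each candidate is

print(beam_indices.shape, token_ids.shape)  # both (batch_size, 2 * num_beams)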

The remaining important pieces are beam_scorer.process, which handles the cur-step candidates, and beam_scorer.finalize, which builds the final output.

# BeamSearchScorer.process
def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor]:
    cur_len = input_ids.shape[-1]
    batch_size = len(self._beam_hyps)
    assert batch_size == (input_ids.shape[0] // self.group_size)

    device = input_ids.device  # self.group_size is simply num_beams here (no diverse beam groups)
    next_beam_scores = torch.zeros((batch_size, self.group_size),  dtype=next_scores.dtype, device=device) 
    next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
    next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

    for batch_idx, beam_hyp in enumerate(self._beam_hyps):  # process the cur-step candidates sentence by sentence
        if self._done[batch_idx]:  # a sentence is marked done once enough hypotheses end with eos_token_id; finished sentences only get padded from here on
            assert (
                len(beam_hyp) >= self.num_beams
            ), f"Batch can only be done if at least {self.num_beams} beams have been generated"
            assert (
                eos_token_id is not None and pad_token_id is not None
            ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
            # pad the batch
            next_beam_scores[batch_idx, :] = 0
            next_beam_tokens[batch_idx, :] = pad_token_id  # the pad tokens filled in here are filtered out later
            next_beam_indices[batch_idx, :] = 0
            continue

        # next tokens for this sentence
        beam_idx = 0
        for beam_token_rank, (next_token, next_score, next_index) in enumerate(
            zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
        ):
            batch_beam_idx = batch_idx * self.group_size + next_index  # because candidates were flattened across beams, add the batch offset; batch_beam_idx records which beam row of input_ids the cur-step token extends, so the final sentence can be traced back
            # add to generated hypotheses if end of sentence
            if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                # if beam_token does not belong to top num_beams tokens, it should not be added
                is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                if is_beam_token_worse_than_top_num_beams:
                    continue
                beam_hyp.add(
                    input_ids[batch_beam_idx].clone(),
                    next_score.item(),
                )
            else:
                # add next predicted token since it is not eos_token
                next_beam_scores[batch_idx, beam_idx] = next_score
                next_beam_tokens[batch_idx, beam_idx] = next_token
                next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                beam_idx += 1

            # once the beam for next step is full, don't add more tokens to it.
            if beam_idx == self.group_size:
                break

        if beam_idx < self.group_size:
            raise ValueError(
                f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
            )
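        # Note (added for clarity): top-k draws 2 * num_beams candidates precisely so that even if
        # up to num_beams of them are eos_token_id (and get routed into beam_hyp instead of
        # continuing), there are still num_beams non-eos candidates left to fill the next step.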

        # Check if we are done so that we can save a pad step if all(done)
        self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
            next_scores[batch_idx].max().item(), cur_len
        )  # decide whether this sentence is finished

    return UserDict(
        {
            "next_beam_scores": next_beam_scores.view(-1),
            "next_beam_tokens": next_beam_tokens.view(-1),
            "next_beam_indices": next_beam_indices.view(-1),
        }
    )
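Both process and finalize push hypotheses into per-sentence BeamHypotheses containers via beam_hyp.add and query beam_hyp.is_done. As a rough, simplified sketch of what that container does (the real class additionally handles early_stopping and a few edge cases), it keeps the top num_beams hypotheses ranked by a length-penalized score:

class SimpleBeamHypotheses:
    """Simplified stand-in for the library's BeamHypotheses container."""

    def __init__(self, num_beams: int, length_penalty: float = 1.0):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.beams = []  # list of (score, token_tensor) tuples
        self.worst_score = 1e9

    def add(self, hyp, sum_logprobs: float):
        # length-penalized score: sum of log-probs divided by length ** length_penalty
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        if len(self.beams) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self.beams) > self.num_beams:
                # drop the current worst hypothesis
                sorted_scores = sorted((s, idx) for idx, (s, _) in enumerate(self.beams))
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        # finished when even the best still-open beam can no longer beat the worst stored hypothesis
        if len(self.beams) < self.num_beams:
            return False
        best_possible = best_sum_logprobs / (cur_len ** self.length_penalty)
        return self.worst_score >= best_possible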

# BeamSearchScorer.finalize
def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    max_length: int,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
) -> Tuple[torch.LongTensor]:
    batch_size = len(self._beam_hyps)

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:  # sentences that hit eos_token_id during decoding were already added to beam_hyp
            continue

        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            beam_hyp.add(final_tokens, final_score)  # hypotheses still open when decoding stops (e.g. the maximum length was reached) are added to the list

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []
    best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):  # rank the hypotheses found by beam_search and output the top-scoring sentences
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])  # each entry of beam_hyp.beams is a (score, decoded_token_tensor) tuple
        for j in range(self.num_beam_hyps_to_keep):
            best_hyp_tuple = sorted_hyps.pop()
            best_score = best_hyp_tuple[0]
            best_hyp = best_hyp_tuple[1]
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

            # append to lists
            best.append(best_hyp)
            best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

    # prepare for adding eos
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, hypo in enumerate(best):
        decoded[i, : sent_lengths[i]] = hypo
        if sent_lengths[i] < max_length:
            decoded[i, sent_lengths[i]] = eos_token_id  # append eos_token_id right after the last token
    return UserDict(
        {
            "sequences": decoded,
            "sequence_scores": best_scores,
        }
    )
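When called through generate(), num_beam_hyps_to_keep corresponds to num_return_sequences, and the "sequences" / "sequence_scores" returned here surface as sequences and sequences_scores on the generate output when return_dict_in_generate=True and output_scores=True. A short usage sketch continuing the BART example above (still an assumed checkpoint):

out = model.generate(
    inputs["input_ids"],
    num_beams=4,
    num_return_sequences=2,
    return_dict_in_generate=True,
    output_scores=True,
    max_length=32,
)
print(out.sequences.shape)   # (batch_size * num_return_sequences, sent_max_len)
print(out.sequences_scores)  # length-penalized beam scores from beam_scorer.finalize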
