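"""
Walk-through of the beam_search procedure from transformers' generation_beam_search.py.

When the decoder input has shape [N, 1] (N is the batch size) and beams=k, the input is expanded to [N*k, 1].
The input is fed to the decoder, which produces logits of shape [N*k, T], where T is the total number of tokens.
The logits are added to the accumulated beam_score to form the new beam_score; a top-k sort over it yields
next_beam_scores, next_beam_index and next_beam_tokens.
How beam_hyps are stored: using the next_beam_* values above, check whether next_token is <eos>; if it is, the
hypothesis is stored, otherwise beams=k next_beam entries are still selected for the next decoder step.
The demo code starts from one number and generates a run of consecutive numbers, stopping once a number ending in 9
is reached.
"""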
import torch
from typing import *
from abc import ABC, abstractmethod
from collections import UserDict
import torch
#from .file_utils import add_start_docstrings
class BeamScorer(ABC):
    """
    Interface shared by every beam scorer used by
    :meth:`~transformers.PretrainedModel.beam_search` and
    :meth:`~transformers.PretrainedModel.beam_sample`.

    Both methods are abstract, so the class cannot be instantiated directly; a concrete scorer has to
    override ``process`` and ``finalize``.
    """

    @abstractmethod
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> Tuple[torch.Tensor]:
        """Consume the candidates of one decoding step. Must be overridden; the base version always raises."""
        raise NotImplementedError("This is an abstract method.")

    @abstractmethod
    def finalize(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> torch.LongTensor:
        """Turn the collected hypotheses into output sequences. Must be overridden; the base version always raises."""
        raise NotImplementedError("This is an abstract method.")
class BeamSearchScorer(BeamScorer):
    """
    Beam scorer for standard beam search.

    Keeps one :class:`BeamHypotheses` n-best list per batch item. At every step :meth:`process` moves candidates
    ending in ``eos_token_id`` into that list and selects the ``group_size`` candidates that continue decoding;
    :meth:`finalize` turns the collected hypotheses into a padded tensor of sequences.
    """

    def __init__(
        self,
        batch_size: int,
        max_length: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
    ):
        """
        Args:
            batch_size: Number of independent inputs; one hypothesis list is created for each.
            max_length: Upper bound on the length of the returned sequences.
            num_beams: Beams per batch item; must be an ``int`` greater than 1.
            device: Device on which the per-batch "done" flags and the final scores are allocated.
            length_penalty: Exponent forwarded to :class:`BeamHypotheses` for length normalisation.
            do_early_stopping: Forwarded to :class:`BeamHypotheses` as ``early_stopping``.
            num_beam_hyps_to_keep: Number of hypotheses returned per batch item by :meth:`finalize`.
            num_beam_groups: Number of beam groups; must be a positive ``int`` that divides ``num_beams``.

        Raises:
            ValueError: If ``num_beams`` or ``num_beam_groups`` is invalid.
        """
        # Validate before any derived value is computed: dividing by `num_beam_groups` first would surface a
        # ZeroDivisionError / TypeError instead of the intended ValueError.
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
            )
        if (
            not isinstance(num_beam_groups, int)
            or num_beam_groups < 1
            or num_beam_groups > num_beams
            or num_beams % num_beam_groups != 0
        ):
            raise ValueError(
                f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
                f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

        self.max_length = max_length
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups
        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                max_length=self.max_length,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        # One flag per batch item, set once that item's hypothesis list reports it is done.
        self._done = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

    @property
    def is_done(self) -> bool:
        """Truthy once every batch item is finished (the value is a 0-dim bool tensor from ``Tensor.all()``)."""
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> UserDict:
        """
        Consume the ranked candidates of one decoding step.

        Args:
            input_ids: Sequences generated so far; its first dimension must equal ``batch_size * group_size``.
            next_scores: Candidate scores, indexed ``[batch_idx, rank]``.
            next_tokens: Candidate token ids, indexed like ``next_scores``.
            next_indices: For each candidate, the beam (within its batch item) it extends.
            pad_token_id: Token written into the beams of batch items that are already done.
            eos_token_id: Token that ends a hypothesis.

        Returns:
            A ``UserDict`` with flattened ``next_beam_scores``, ``next_beam_tokens`` and ``next_beam_indices``
            (each holding ``batch_size * group_size`` entries); the indices are rows of ``input_ids``.

        Raises:
            ValueError: If fewer than ``group_size`` non-EOS candidates were available for a batch item.
        """
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        assert batch_size == (input_ids.shape[0] // self.group_size)
        device = input_ids.device
        out_shape = (batch_size, self.group_size)
        next_beam_scores = torch.zeros(out_shape, dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros(out_shape, dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros(out_shape, dtype=next_indices.dtype, device=device)

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                # The batch item is completely finished, but its next_beam_* slots still have to be filled.
                assert (
                    len(beam_hyp) >= self.num_beams
                ), f"Batch can only be done if at least {self.num_beams} beams have been generated"
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                # Row of `input_ids` that this candidate extends.
                batch_beam_idx = batch_idx * self.group_size + next_index
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # An EOS candidate ranked outside the top `group_size` is dropped, not stored.
                    if beam_token_rank >= self.group_size:
                        continue
                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                    )
                else:
                    # Not EOS: the candidate becomes one of the beams for the next step.
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1
                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> UserDict:
        """
        Close all open beams and assemble the best hypotheses into one tensor.

        Args:
            input_ids: Final sequences, one row per beam (``batch_size * num_beams`` rows are read).
            final_beam_scores: Final score of every beam, indexed like the rows of ``input_ids``.
            final_beam_tokens: Accepted for interface compatibility; not read here.
            final_beam_indices: Accepted for interface compatibility; not read here.
            pad_token_id: Fill value for shorter sequences; required when the selected lengths differ.
            eos_token_id: Token written after every sequence that is shorter than ``max_length``.

        Returns:
            A ``UserDict`` with ``sequences`` (``batch_size * num_beam_hyps_to_keep`` rows) and
            ``sequence_scores`` (one score per returned row).
        """
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue
            # Every open beam is offered to the hypothesis list, which keeps only the best ones itself.
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score)

        # select the best hypotheses
        num_returned = batch_size * self.num_beam_hyps_to_keep
        # Every entry is assigned in the loop below, so an uninitialised buffer is sufficient.
        sent_lengths = input_ids.new_empty(num_returned)
        best = []
        best_scores = torch.zeros(num_returned, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            # Ascending by score, so `pop()` yields the best remaining hypothesis.
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_score, best_hyp = sorted_hyps.pop()[:2]
                slot = self.num_beam_hyps_to_keep * i + j
                sent_lengths[slot] = len(best_hyp)
                best.append(best_hyp)
                best_scores[slot] = best_score

        # prepare for adding eos
        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
        decoded: torch.LongTensor = input_ids.new_empty((num_returned, sent_max_len))
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < self.max_length:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
            }
        )
class BeamHypotheses:
    """Bounded n-best list of finished hypotheses, ranked by length-normalised score."""

    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
        """
        Create an empty n-best list that holds at most ``num_beams`` hypotheses.
        """
        self.num_beams = num_beams
        self.early_stopping = early_stopping
        self.length_penalty = length_penalty
        self.max_length = max_length - 1  # ignoring bos_token
        self.beams = []
        # Starts high; the first stored hypothesis lowers it through `min`.
        self.worst_score = 1e9

    def __len__(self):
        """
        Return how many hypotheses are currently stored.
        """
        return len(self.beams)

    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
        """
        Offer a hypothesis; it is kept if there is room or if it beats the current worst score.
        """
        normalised = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        has_room = len(self) < self.num_beams
        if not (has_room or normalised > self.worst_score):
            return

        self.beams.append((normalised, hyp))
        if len(self) <= self.num_beams:
            self.worst_score = min(normalised, self.worst_score)
            return

        # Over capacity: evict the lowest-scoring entry; the runner-up becomes the new worst score.
        ranked = sorted((kept_score, position) for position, (kept_score, _) in enumerate(self.beams))
        del self.beams[ranked[0][1]]
        self.worst_score = ranked[1][0]

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        """
        Report whether the list is full and no hypothesis still being generated can beat the worst stored one
        (with early stopping, a full list is enough).
        """
        if len(self) < self.num_beams:
            return False
        if self.early_stopping:
            return True
        best_possible = best_sum_logprobs / cur_len ** self.length_penalty
        return self.worst_score >= best_possible
class ToyDecoder():
    """
    Minimal stand-in for a seq2seq model, used to demonstrate beam search.

    The "model" (:meth:`__call__`) is deterministic: for an input number ``n`` it gives the highest logits to
    ``n + 1``, ``n + 2``, ...; when ``n`` ends in 9 it favours token 2, which the demo uses as the EOS token.
    The vocabulary is the integers 0-99.
    """

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ) -> Optional[torch.LongTensor]:
        """
        Run beam search starting from the last token of every row of ``input_ids``.

        Only plain beam search is implemented: ``num_beams > 1``, ``num_beam_groups == 1`` and
        ``do_sample is False``. For any other combination the method returns ``None``. ``temperature``,
        ``top_k``, ``top_p`` and ``use_cache`` are accepted but not used.

        Args:
            input_ids: Prompt tokens; ignored when ``decoder_input_ids`` is passed in ``model_kwargs``.
            max_length: Maximum length of the generated sequences.
            num_beams: Number of beams per input row.
            length_penalty: Falls back to ``self.config.length_penalty`` if a config exists, else ``1.0``.
            early_stopping: Falls back to ``self.config.early_stopping`` if a config exists, else ``False``.
            num_return_sequences: Sequences returned per input row; defaults to 1 when omitted.
            pad_token_id / eos_token_id: Forwarded to the beam scorer.

        Returns:
            The ``sequences`` tensor produced by :meth:`beam_search`, or ``None`` for unsupported modes.

        Raises:
            ValueError: If ``num_return_sequences`` exceeds ``num_beams``.
        """
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states

        # set input_ids as decoder_input_ids
        if "decoder_input_ids" in model_kwargs:
            input_ids = model_kwargs.pop("decoder_input_ids")
        else:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
            )

        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            encoder_input_ids=input_ids,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            eos_token_id=eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
        )

        is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
        if not is_beam_gen_mode:
            # No other generation mode is implemented in this toy.
            return None

        batch_size = input_ids.shape[0]
        # This class defines no `config`, so fall back to neutral defaults instead of failing on `self.config`.
        config = getattr(self, "config", None)
        if length_penalty is None:
            length_penalty = getattr(config, "length_penalty", 1.0)
        if early_stopping is None:
            early_stopping = getattr(config, "early_stopping", False)
        if num_return_sequences is None:
            num_return_sequences = 1
        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        self.device = input_ids.device
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=self.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
        )
        # Repeat every row `num_beams` times: [N, 1] -> [N * num_beams, 1].
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids, expand_size=num_beams, is_encoder_decoder=True, **model_kwargs
        )
        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            **model_kwargs,
        )

    def _prepare_decoder_input_ids_for_generation(
        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
    ) -> torch.LongTensor:
        """Return the last token of every row as a ``[N, 1]`` tensor; the two id arguments are not used."""
        decoder_input_ids = input_ids[:, -1].unsqueeze(-1)
        return decoder_input_ids

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: torch.LongTensor = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """
        Repeat every row of ``input_ids`` ``expand_size`` times, keeping the copies adjacent.

        ``token_type_ids`` (if present in ``model_kwargs``) and ``attention_mask`` (if given) are expanded the
        same way and stored in the returned kwargs. ``is_encoder_decoder`` is accepted but has no effect here.

        Returns:
            ``(expanded_input_ids, model_kwargs)``.
        """
        # Index [0, 0, ..., 1, 1, ...]: each row index repeated `expand_size` times.
        expanded_return_idx = (
            torch.arange(input_ids.shape[0], device=input_ids.device).view(-1, 1).repeat(1, expand_size).view(-1)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
        return input_ids, model_kwargs

    def _get_logits_processor(
        self,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        encoder_no_repeat_ngram_size: int,
        encoder_input_ids: torch.LongTensor,
        bad_words_ids: List[List[int]],
        min_length: int,
        eos_token_id: int,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        num_beam_groups: int,
        diversity_penalty: float,
    ):
        """Placeholder: the toy applies no logits processors and always returns ``None``."""
        return None

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: "BeamScorer",
        logits_processor: Optional[List] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Decode step by step until ``max_length`` and return the scorer's best sequences.

        Args:
            input_ids: Already expanded inputs with ``batch_size * num_beams`` rows.
            beam_scorer: Scorer providing ``_beam_hyps``, ``num_beams``, ``process`` and ``finalize``.
            logits_processor: Accepted for interface compatibility; not applied.
            max_length: Decoding stops once the sequences reach this length.
            pad_token_id / eos_token_id: Forwarded to the scorer.
            output_attentions / output_hidden_states: Forwarded to the model call.
            output_scores / return_dict_in_generate: Accepted for interface compatibility; not used.

        Returns:
            The ``"sequences"`` entry of ``beam_scorer.finalize(...)``.
        """
        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        batch_beam_size, cur_len = input_ids.shape
        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        # All beams of an item start from the same token, so only the first beam may win the first top-k;
        # otherwise the k selected candidates would be identical.
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        # Defined up front so that `finalize` can still be called when the loop body never runs.
        next_tokens = None
        next_indices = None
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            next_token_logits = outputs['logits'][:, -1, :]
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, cur_len=cur_len, max_length=max_length
            )
            # Toy scoring: scaled logits are used directly as step scores and added to the running beam score.
            next_token_scores = next_token_logits / 100
            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

            # Rank all (beam, token) pairs of one batch item together.
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            # Keep 2 * num_beams candidates so enough non-EOS ones remain to refill every beam.
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # Split the flat index back into (beam index, token id).
            next_indices = next_tokens // vocab_size
            next_tokens = next_tokens % vocab_size

            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            # Reorder the histories to the surviving beams and append the chosen tokens.
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )
        return sequence_outputs["sequences"]

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """Build the keyword arguments for one model call; only the last token of each sequence is passed on."""
        decoder_input_ids = decoder_input_ids[:, -1:]
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Simplified "forward" pass that fabricates logits over the 100 numbers 0-99.

        Only ``decoder_input_ids`` is read. For every row, with ``n`` the last token, the (up to) ten tokens
        following ``n`` get the logits 99, 98, ... and all others stay 0. If ``n`` ends in 9, ``n`` is replaced
        by 1 so that token 2 (the demo's EOS token) gets the top logit.

        Returns:
            ``{'logits': tensor}`` whose shape is the shape of ``decoder_input_ids`` plus a trailing 100.
        """
        input_shape = decoder_input_ids.size()
        decoder_input_ids = decoder_input_ids.view(-1, input_shape[-1])
        lm_logits = torch.zeros(tuple(decoder_input_ids.shape) + (100,), device=decoder_input_ids.device)

        # Read the last column explicitly; `squeeze()` would collapse a single-row input to a 0-d tensor.
        for row, last_token in enumerate(decoder_input_ids[:, -1].tolist()):
            # A number ending in 9 stops the run: restart from 1 so that token 2 (EOS) ranks first.
            start = 1 if (last_token + 1) % 10 == 0 else last_token
            stop = min(start + 1 + 10, 99)
            width = stop - start - 1
            lm_logits[row, :, start + 1:stop] = torch.arange(
                99, 99 - width, step=-1, dtype=lm_logits.dtype, device=lm_logits.device
            )
        return {'logits': lm_logits}

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        """On the last step (``cur_len == max_length - 1``) force token 2 (the demo's EOS); else leave logits as is."""
        if cur_len == max_length - 1:
            self._force_token_id_to_be_generated(logits, 2)
        return logits

    @staticmethod
    def _force_token_id_to_be_generated(scores, token_id) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
        scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
if __name__ == '__main__':
    # Demo: draw a random 2 x 5 batch of numbers in [0, 100), show the start token taken from each row,
    # then beam-search 4 sequences per row up to length 8 (token 2 acts as EOS, token 1 as padding).
    prompt = torch.randint(0, 100, (2, 5))
    print(prompt)

    toy = ToyDecoder()
    start_tokens = toy._prepare_decoder_input_ids_for_generation(prompt)
    print(start_tokens)

    beam_settings = dict(
        num_beams=4,
        num_beam_groups=1,
        do_sample=False,
        length_penalty=1,
        early_stopping=True,
        num_return_sequences=4,
        eos_token_id=2,
        pad_token_id=1,
    )
    sequences = toy.generate(prompt, 8, **beam_settings)
    print(sequences)
# Result (sample output of one run; `input_ids` is drawn at random, so the values differ between runs):
# tensor([[91, 50, 26, 71, 23],
#         [25, 22, 31, 20, 71]])
# tensor([[23],
#         [71]])
# tensor([[23, 24, 25, 26, 27, 28, 29, 2],
#         [23, 24, 25, 26, 27, 29, 2, 1],
#         [23, 24, 25, 26, 28, 29, 2, 1],
#         [23, 24, 25, 27, 28, 29, 2, 1],
#         [71, 72, 73, 74, 75, 76, 77, 2],
#         [71, 72, 73, 74, 75, 77, 78, 2],
#         [71, 72, 73, 75, 76, 77, 78, 2],
#         [71, 72, 74, 75, 76, 77, 78, 2]])