比起两年前,NLG任务已经得到了非常有效的发展,transformers模块的使用广泛程度也达到前所未有的程度。在模型推理预测时,一个核心的语句就是model.generate()
,本文就来详细介绍一下generate方法是如何运作的。在生成的过程中,包含了诸多生成策略,本文将以最常用的beam search为例,在本人能力范围内,尽可能详细地展开介绍。
考虑到篇幅可能会比较长,本文将分为上下两篇,上篇主要介绍generate方法的整体结构,下篇将对beam search部分的代码进行进一步的介绍。
随着各种LLM的出现,transformers中与generate相关的代码发生了一些变化,主要区别在于:
在之前版本的transformers中(transformers~=4.9),generate方法位于transformers.generation_utils.py
,这个方法是GenerationMixin
类的一个方法。
而在新版本的transformers中(transformers~=4.28),generate方法被转移到了transformers.generation.utils.py
,仍然是GenerationMixin
的一个类方法。
而对于一个hf形式的预训练模型,都是继承了PreTrainedModel
类的,而顺着这个PreTrainedModel
类,可以看到更上一级的继承逻辑,GenerationMixin
就在其中:
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
这就是为什么通过AutoModel.from_pretrained()
实例化的一个model为什么可以直接调用generate
方法去做推理。
这一部分作为一个速查表写在这里,不建议直接阅读,而是在读第4节代码的过程中,返回来查看这部分内容。
GenerationMixin
类所有方法概览如下:
方法名 | 作用 | 在本文中出现的位置 |
---|---|---|
_validate_model_class | 检修该模型是否可以做生成,并抛出相应的异常 | 4.1 |
_validate_model_kwargs | 检查generation config中的参数是否与生成策略相匹配 | 4.1 |
_prepare_model_inputs | 为生成过程准备输入 | 4.3 |
_maybe_initialize_input_ids_for_generation | 当生成过程的inputs为空时,使用bos token做初始化 | 4.3 |
_prepare_attention_mask_for_generation | 为生成过程准备attention_mask | 4.4 |
_prepare_encoder_decoder_kwargs_for_generation | 为生成过程准备encoder相关的参数 | 4.4 |
_prepare_decoder_input_ids_for_generation | 为自回归模型额外处理input_ids | 4.5 |
_get_decoder_start_token_id | 获取decoder的开始位置的token id,这个id可能与bos不同 | 4.5 |
_get_logits_processor | 创建logits处理器 | 4.8 |
_get_stopping_criteria | 创建停止规则 | 4.9 |
_get_logits_warper | 创建logits warper | 4.11 |
_expand_inputs_for_generation | 根据num_beams对input_ids进行扩展 | 4.12 |
prepare_inputs_for_generation | 对模型的输入进行预处理 | 下篇3.1 |
adjust_logits_during_generation | 在生成过程中对计算的logits进行调整 | 下篇3.1 |
_update_model_kwargs_for_generation | 根据一个step的生成结果,更新生成参数 | 下篇5.6 |
_reorder_cache | 根据step更新的beam_idx,对缓存的past_k_v进行重排 | 下篇5.6 |
在介绍流程之前先看一下generate方法的签名,在4.28版本中,其签名简化如下:
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
相比之前的版本,这样写的直接优点就是,与原版的超长签名相比,减少了传入的参数,将诸如top_k
, top_p
, num_beams
等参数全部都整合到了generation_config
中,使得函数看起来更加简化,并且该参数可以直接从模型路径下的generation_config.json文件中读取,一定程度上为用户提供了便捷。
相应的缺点就是很多参数没有显性地暴露出来,在查看注释和自定义生成配置的时候就不是很方便了。
需要在GenerationConfig
中查看可选的参数:
from transformers.generation.configuration_utils import GenerationConfig
help(GenerationConfig)
(GenerationConfig
中各类生成策略对应的参数各有不同,这里不展开介绍,在本文的下篇中,将对beam search策略下的参数进行简介。)
generate方法的参数含义与作用介绍如下:
参数名 | 类型 | 含义与作用 |
---|---|---|
inputs | torch.Tensor | tokenize之后的序列id,模型将基于这个序列计算logits并进行生成。如果为空,则默认为bos token对应的id |
generation_config | GenerationConfig | 各种生成策略对应的参数,如果为空,将会从模型路径下的generation_config.json文件中读取,或从model config获取 |
logits_processor | LogitsProcessorList | 对模型计算出的logits进行进一步处理,例如对“复读机现象”相应的概率进行惩罚,以避免模型生成结果不断重复 |
stopping_criteria | StoppingCriteriaList | 对生成过程做停止控制的工具,例如达到一定长度时强行停止,达到一定生成时间时停止等 |
prefix_allowed_tokens_fn | [int, torch.Tensor], List[int] | beam search过程中,每个step允许生成的token id范围 |
synced_gpus | bool | 采用DeepSpeed ZeRO时使用 |
streamer | BaseStreamer | stream generate时使用(也就是一个字一个字的往外蹦的效果) |
在这些输入中,logits_processor和stopping_criteria,将是用户手动干预生成过程的主要手段。
在4.28版本的transformers代码中,generate过程的注释写的比较条理清晰,所以本文也沿用代码注释中的序号进行划分。
这一部分的大概逻辑就是处理generation config为None的情况,以及检查是否存在与生成策略不一致的错误参数。
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation)"
)
self.generation_config = new_generation_config
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
其中_validate_model_class
和_validate_model_kwargs
两个方法都不是重点,这里不展开介绍。
这部分需要补充的参数包括logits_processor
, stopping_criteria
, 以及generation_config
中的pad_token_id
。前两项是设置为默认的空list;pad_token_id没有给定,而eos给定的话,用eos来做padding。
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
这里主要需要关注_prepare_model_inputs
这个方法,这个方法的核心,一句话概括就是模型输入的序列input_ids,必须非空,如果空的话,就用bos_token去初始化。其余部分都是用来应对个别模型的特殊情况:
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# 这一步似乎是起到一个校准的作用,防止某些encoder-decoder模型的主模型和encoder的输入名称不一致
# 1. retrieve all kwargs that are non-None or non-model input related.
# some encoder-decoder models have different names for model and encoder
if (
self.config.is_encoder_decoder
and hasattr(self, "encoder")
and self.encoder.main_input_name != self.main_input_name
):
input_name = self.encoder.main_input_name
else:
input_name = self.main_input_name
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
# 确保inputs没有重复传入
# 2. check whether model_input_name is passed as kwarg
# if yes and `inputs` is None use kwarg inputs
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
# 对于inputs_embeds这一输入参数:
# 如果是decoder-only模型,需要把'input_ids'这一参数放在inputs_kwarg中传入
# 如果是encoder-decoder模型,input_ids与inputs_embeds只能传入其一
# 3. In the presence of `inputs_embeds` for text models:
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
raise ValueError(
f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs=model_kwargs
)
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
# 如果最后还是没有input_ids, 采用bos创建input_ids,可以简化理解为:
# torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
这一部分没有需要特别注意的地方,主要就是一些config设置,补齐模型的其他参数,如创建attention_mask,确保encoder-decoder模型能够返回’ModelOutput’类等等。
# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
if (
generation_config.pad_token_id is not None
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name
)
这一步与4.3的主要区别在于,针对AR模型额外进行了处理。如果是encoder-decoder模型,则直接采用4.3创建的input_tensor作为input_ids。
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# 这里主要是针对decoder的开始位置id与bos id不同的情况
# 在这种情况下,decoder-only模型应当以配置中规定的decoder start id开始进行生成
# 此处可简单理解为:torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device,
)
# conditional generation for multi-modal models.
if "input_ids" in model_kwargs and model_input_name == "pixel_values":
input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
这一部分就是根据config中的相关配置,判断input_id的长度有没有超长。
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if not has_default_max_length:
logger.warn(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
UserWarning,
)
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
这里直接选择beam search分支了,其他模式不做展开介绍,下同。
beam search分为两种,基础款的beam_gen_mode
,以及进阶款的beam_sample_gen_mode
,其中,前者对应后续的生成方法为beam_search
,后者对应后续的生成方法为beam_sample
。
二者的区别主要在于,进阶款的beam_sample_gen_mode
可以设置temperature、top_k、top_p等参数进一步控制生成,设置的方法在4.11节:logits warper中介绍。对于基础款的beam_gen_mode
,就没有创建logits warper这一环节。
# 7. determine generation mode
is_beam_sample_gen_mode = (
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
这一个环节比较重要,因为涉及到了logits processor。这些processor是在生成的过程中,在每一个step,对计算出来的得分进行修正处理的。在transformers
中,预设了若干processor,用户也可以定义自己的processor(需要继承抽象类transformers.generation.logit_process.LogitsProcessor),自己设计逻辑,来对生成的过程进行人工干预。
在beam search中,logits process的使用方法是:
# 在def beam_sample中使用
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
其中,input_ids是当前step传给模型的序列token id对应Tensor(batch_size, sequence_length),next_token_scores是经过模型计算之后的分数(即在vocab上的概率分布)取log_softmax。
在这里简单介绍一下在transformers
中预设的processor。限于篇幅,不贴出全部源码,只对其功能进行总结。
processor | 作用 | 参考连接 |
---|---|---|
MinLengthLogitsProcessor | 通过将EOS的概率强行设置为0,保证生成结果的长度大于等于一个最小值 | / |
MinNewTokensLengthLogitsProcessor | 与上一个类似,但是prompt的部分不计入生成长度 | / |
RepetitionPenaltyLogitsProcessor | 防止“复读机”现象,给重复出现token添加惩罚,由预训练模型CTRL提出 | arxiv |
EncoderRepetitionPenaltyLogitsProcessor | 与上一个区别在于,生成的结果不能与encoder输入input id重复,而非与当前给定的全部input id | / |
NoRepeatNGramLogitsProcessor | 防止生成的文本中出现重复的n-gram(n个连续的词或字符),区别在于限制连续n个 | github |
EncoderNoRepeatNGramLogitsProcessor | n-gram可以在encoder的input ids中重复,不可以在decoder重复 | github |
NoBadWordsLogitsProcessor | 确保某些词永远不会被生成 | / |
PrefixConstrainedLogitsProcessor | 给定一个prefix_allow_func来限制符合哪些条件的token可以被生成 | arxiv |
HammingDiversityLogitsProcessor | 以Hamming距离为标准,确保生成的各个beam之前的区别足够大 | arxiv |
ForcedBOSTokenLogitsProcessor | 确保生成的第一个token是某个特定的token | / |
ForcedEOSTokenLogitsProcessor | 达到最大长度时,确保以某个特定的token作为结束 | / |
InfNanRemoveLogitsProcessor | 将计算出的得分中,nan替换为0,inf替换为可计算的最大值 | / |
SuppressTokensAtBeginLogitsProcessor | 在达到某个长度之后,将不再生成某些特定的词 | / |
SuppressTokensLogitsProcessor | 将某些特定词的概率设置为-inf,不生成这些词 | / |
ForceTokensLogitsProcessor | 建立一个映射表,把某个token强行映射成另一个token | / |
WhisperTimeStampLogitsProcessor | 强制模型生成时间戳(时间戳是一个特殊token,例如对话中,query=今天是周几?,answer=今天是[timestamp],这个[timestamp]后期会替换成对应的时间) | / |
stopping_criteria与logits_processor是用户对生成过程进行干预的主要手段,相比logits_processor强行改变概率空间,stopping_criteria则是直接设定了终止生成的策略,理解起来也会相对容易一些。
# 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
预设的criteria总结如下:
criteria | 作用 |
---|---|
MaxLengthCriteria | 生成的序列达到设置的最大长度时,停止生成 |
MaxNewTokensCriteria | 生成的序列中,除去prompt的部分达到设置的最大长度时,停止生成 |
MaxTimeCriteria | 生成的耗时超过一定时间限制时,停止生成 |
如果是自定义criteria,应当继承抽象类transformers.generation.stopping_criteria.StoppingCriteria
。
这里直接选择进入beam search的分支。如前文所述,如果要控制temperature等超参数,则应该进入is_beam_sample_gen_mode这个分支。
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits warper的使用方法与logits processor一样,都是用来修改概率的输出。关于他们的区别,暂时没有找到很好的解释,可以理解为warper控制着temperature、topk等与生成策略相关的参数。并且是在logits processor处理之后再进行处理的。
普通的beam search不会涉及这一部分,只有选择sample模式的beam search时,才会使用到logits warper。
需要记住的是,它的输入与processor一样,都是当前的序列(token_ids)与之前计算出的得分(scores),返回的结果是处理之后的得分,形状是(batch_size, config.vocab_size)
。
预设的warper包括:
warper | 作用(仅供参考) | 参考链接 |
---|---|---|
TemperatureLogitsWarper | 对score整体除以temperature做缩放 | / |
TopPLogitsWarper | 概率小于topp的得分置为0 | / |
TopKLogitsWarper | 只取topk的概率对应的词汇,其余的概率置为-inf | / |
TypicalLogitsWarper | typical decoding | arxiv |
EpsilonLogitsWarper | 将概率小于epsilon的token移除 | arxiv |
EtaLogitsWarper | eta-sampling | arxiv |
LogitNormalization | 在beam search进行的过程中做layernorm | / |
这一部分是beam search的核心流程,限于篇幅,其具体的执行生成过程将在本文的下篇中进行详细的介绍。
在这一部分中,首先创建了用于打分的BeamSearchScorer(具体作用将在下篇中进行介绍),然后根据num_beams对input_ids进行了扩展,最后执行beam search的核心方法beam_search
,或beam sample对应的beam_sample
方法。
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 13. run beam search
return self.beam_search(
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
在本文的下篇中,将介绍beam search的基本原理,transformers模块对于beam search的实现方法中,主要涉及的几个工具组件,beam search的生成与更新过程,以及beam sample对beam search的改进实现,感兴趣的朋友可以继续阅读。