1. The following model_kwargs are not used by the model: ['encoder_hidden_states', 'encoder_attention_mask'] (note: typos in the generate arguments will also show up in this list)
Calling generate() on text_decoder (a BertGenerationDecoder) raises the error above. The cause is an incompatible transformers version.
# Excerpt from a model class: the decoder is built in __init__ and used later in a generation method
from transformers import AutoModel, AutoConfig, BertGenerationDecoder

decoder_config = AutoConfig.from_pretrained(args['text_checkpoint'])
self.text_decoder = BertGenerationDecoder(config=decoder_config)

# Generation step; cls_input_ids, encoder_hidden_states and encoder_attention_mask come from the surrounding model
output = self.text_decoder.generate(
    input_ids=cls_input_ids,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
    max_length=self.args['max_seq_length'],
    do_sample=True,
    num_beams=self.args['beam_size'],
    length_penalty=1.0,
    use_cache=True,
)
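One way a message like this can arise is a signature-based keyword check. Each extra keyword passed to generate() is compared against the parameters that some input-preparation function declares. The sketch below is a simplified, hypothetical illustration of that idea and is not transformers' actual code, which differs between versions. `unused_kwargs` and `prepare_inputs` are made-up names. If the check ignores a `**model_kwargs` catch-all, keywords such as encoder_hidden_states are reported as unused even though the function itself would accept them.

```python
import inspect
from typing import Any, Callable, Dict, List


def unused_kwargs(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> List[str]:
    """Return keys of kwargs (with non-None values) that fn does not name explicitly.

    Simplification: a **var-keyword parameter is NOT treated as accepting
    everything, so keys that would only reach fn through it are reported.
    """
    explicit = {
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    }
    return [key for key, value in kwargs.items() if value is not None and key not in explicit]


# Hypothetical input-preparation hook: the encoder tensors only arrive via **model_kwargs
def prepare_inputs(input_ids, past=None, attention_mask=None, **model_kwargs):
    return {"input_ids": input_ids, "past": past, "attention_mask": attention_mask, **model_kwargs}


extra = {"encoder_hidden_states": object(), "encoder_attention_mask": object()}
print(unused_kwargs(prepare_inputs, extra))
# ['encoder_hidden_states', 'encoder_attention_mask']
```

Keys whose value is None are skipped, so passing encoder_hidden_states=None would not be flagged in this sketch.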
Solution: switch transformers to a version in one of these ranges: 4.15.0 <= transformers < 4.22.0, or transformers >= 4.25.0.
For example: pip install transformers==4.25.1 or pip install transformers==4.20.1
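A small helper can check an environment against the version ranges given above before calling generate(). It parses transformers.__version__ and tests it against each range. This is a sketch with some limits:

- The ranges are the ones reported in this post and are not independently verified.
- The parser only reads the leading major.minor.patch digits, so a suffix such as .dev0 is ignored.
- The `packaging` library's version parsing is the more thorough alternative.

```python
import re
from typing import Optional, Tuple

# Compatible ranges as reported in this post: each is [low, high); None means no upper bound
_RANGES = [((4, 15, 0), (4, 22, 0)), ((4, 25, 0), None)]


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse the leading 'major.minor[.patch]' digits of a version string."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def in_compatible_range(version: str) -> bool:
    """True if version lies in one of the half-open ranges in _RANGES."""
    ver = parse_version(version)
    return any(low <= ver and (high is None or ver < high) for low, high in _RANGES)


def check_installed() -> Optional[bool]:
    """Check the installed transformers version; None if transformers is not installed."""
    try:
        import transformers
    except ImportError:
        return None
    return in_compatible_range(transformers.__version__)
```

For example, in_compatible_range("4.20.1") and in_compatible_range("4.25.1") return True, while in_compatible_range("4.23.1") returns False.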
2. No module named 'transformers.generation_beam_constraints' (with transformers==4.28.1)
(1) Solution
Change: from transformers import generation_beam_constraints
To: from transformers.generation import beam_constraints
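Some code has to run on both the older flat layout (transformers.generation_beam_constraints) and the newer package layout (transformers.generation.beam_constraints). One option is to try the candidate module paths in order. The helper below is a hypothetical sketch built only on the standard library's importlib. It returns the first module that imports successfully, or None if none do.

```python
import importlib
from types import ModuleType
from typing import Optional, Sequence

# Candidate paths for the beam-constraints module: newer package layout first, then the older flat one
BEAM_CONSTRAINTS_CANDIDATES = (
    "transformers.generation.beam_constraints",
    "transformers.generation_beam_constraints",
)


def import_first(candidates: Sequence[str]) -> Optional[ModuleType]:
    """Import and return the first module in candidates that can be imported, else None."""
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so missing modules land here too
            continue
    return None
```

Usage: beam_constraints = import_first(BEAM_CONSTRAINTS_CANDIDATES), then beam_constraints.Constraint. Checking the result for None first turns a missing install into a clear error instead of an AttributeError.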
(2) Another example
Problematic code:
# Runs on transformers==4.23.1
from transformers.generation_beam_constraints import Constraint
from transformers.generation_beam_search import BeamScorer, BeamSearchScorer
from transformers.generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
Fixed code:
# Runs on transformers==4.28.1
from transformers.generation.beam_constraints import Constraint
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer
from transformers.generation.logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
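The before/after pair above follows one pattern: transformers.generation_X becomes transformers.generation.X. For a codebase with many such imports, a small regex rewrite can do the renaming. This is a sketch with some limits:

- It assumes the same pattern holds for every generation_* module you use, so check that each rewritten module exists in your installed version.
- It handles only one module per `from transformers import ...` line.
- The function name is hypothetical.

```python
import re

# "transformers.generation_beam_search" -> "transformers.generation.beam_search"
_DOTTED_OLD = re.compile(r"\btransformers\.generation_(\w+)")
# "from transformers import generation_beam_constraints"
#   -> "from transformers.generation import beam_constraints"
_FROM_IMPORT_OLD = re.compile(r"\bfrom\s+transformers\s+import\s+generation_(\w+)")


def migrate_generation_imports(source: str) -> str:
    """Rewrite old flat generation_* module imports to the generation package layout."""
    source = _DOTTED_OLD.sub(r"transformers.generation.\1", source)
    return _FROM_IMPORT_OLD.sub(r"from transformers.generation import \1", source)


if __name__ == "__main__":
    old = "from transformers.generation_beam_search import BeamScorer, BeamSearchScorer"
    print(migrate_generation_imports(old))
    # from transformers.generation.beam_search import BeamScorer, BeamSearchScorer
```

Lines that already use the new layout are left unchanged, because the patterns only match the underscore form generation_.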