最近遇到一个很奇怪的BUG,好早之前写的一个Bert文本分类模型,拿给别人用的时候,发现不灵了,原本90多的acc,什么都没修改,再测一次发现只剩30多了,检查了一番之后,很快我发现他的transformers版本是4.24,而我一直用的是4.9,没有更新。
于是我试着分析问题出在哪里,然后就遇到了这个坑。首先这是我模型的基础结构,很简单,就是一个Encoder模型加一层分类器:
class BertClassifier(torch.nn.Module):
def __init__(self, bert_model, num_classes):
super(BertClassifier, self).__init__()
self.bert = bert_model
self.dropout = torch.nn.Dropout(0.2)
self.dense = torch.nn.Linear(768, num_classes)
def forward(
self,
input_ids=None,
token_type_ids=None,
attention_mask=None,
labels=None,
):
bert_out = self.bert(input_ids, token_type_ids, attention_mask, output_attentions=False)
# print(list(self.bert.encoder.layer[0].attention.self.query.parameters()))
# print(bert_out)
sequence_output = bert_out.last_hidden_state
print(sequence_output)
sequence_output = self.dropout(sequence_output)
pool_output = torch.mean(sequence_output, axis=1)
logits = self.dense(pool_output)
# print(logits)
loss = None
loss_fct = torch.nn.CrossEntropyLoss()
if labels is not None:
# labels = label.long()
loss = loss_fct(logits, labels.view(-1))
return loss if loss is not None else logits
为了分析问题出在哪里,我把类里的代码全都拿出来,逐行运行,发现最终的logits和正确的logits(在4.9版本的环境里执行的结果)是一致的,这就很奇怪了,但是我实例化模型,再用模型forward出来的结果却是错误的:
# 这个结果计算出来是对的
sequence_output = bert_cls_model.bert(**inputs).last_hidden_state
sequence_output = bert_cls_model.dropout(sequence_output)
pool_output = torch.mean(sequence_output, axis=1)
logits = bert_cls_model.dense(pool_output)
print(logits)
# 这样计算出来是错的
logits = bert_cls_model(**inputs)
print(logits)
于是我又在模型类的定义里打印了各个阶段的结果,如上第一段代码中的print,发现从bert_out的打印结果来看全都是错的。
更进一步地,为了确认是不是模型加载权重的时候出现了问题(比如加载权重后的模型被重新初始化了),我又在模型定义代码里打印了模型的参数值,确认参数值也是没有问题的。这就让我感到有些匪夷所思了。
我又按照同样的对比方法,在模型里边打印一次,单独拿出来打印一次,试着找出问题所在,这次是从一开始embedding开始,结果发现在模型内部和外部打印embedding的结果是一致的:
# 这样打印的结果是正确的
bert_cls_model.bert.embeddings(input_ids=inputs['input_ids'], token_type_ids=inputs['token_type_ids'])
# 在模型的forward方法里打印embedding的结果同样是正确的
更奇怪的是,我将embedding的结果输入给encoder手动计算,出来的sequence_out就变成正确的了:
class BertClassifier(torch.nn.Module):
def __init__(self, bert_model, num_classes):
super(BertClassifier, self).__init__()
self.bert = bert_model
self.dropout = torch.nn.Dropout(0.2)
self.dense = torch.nn.Linear(768, num_classes)
def forward(
self,
input_ids=None,
token_type_ids=None,
attention_mask=None,
labels=None,
):
# 直接调用self.bert计算出来结果是错误的
# bert_out = self.bert(input_ids, token_type_ids, attention_mask, output_attentions=False)
# 手动以此调用embedding和encoder,就算出来的结果就是正确的了
embedding_res = self.bert.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
encoder_out = self.bert.encoder(embedding_res)
sequence_output = encoder_out[0]
sequence_output = self.dropout(sequence_output)
pool_output = torch.mean(sequence_output, axis=1)
logits = self.dense(pool_output)
# print(logits)
loss = None
loss_fct = torch.nn.CrossEntropyLoss()
if labels is not None:
# labels = label.long()
loss = loss_fct(logits, labels.view(-1))
return loss if loss is not None else logits
最后我又额外检查了一遍两个版本源码的差别,也没有发现什么端倪,感觉修改的地方都是些写法的差异,不应该有能够造成这个问题的地方。
解决的话,目前就是把transformers的版本降下来,或者像最后这样手动执行计算,还没有发现真正出问题的地方在哪里,如果有哪位也遇到这个问题并且有效解决了的话,还请在评论区指出,谢谢。