I previously fine-tuned a BERT model from the Transformers library on my own text-classification task using the library's Trainer API. Today I tried loading a saved checkpoint back in to run evaluation directly, only to find that the parameters would not match. This post records the cause of the problem and how I solved it.
I had run several rounds of training on a machine with only two GPUs (cuda 0 and 1), and the model was saved directly as a single file. When I later reloaded it on another machine with four GPUs (cuda 0 through 3) to continue training, the model still occupied only cards 0 and 1. Calling
model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
directly raises
RuntimeError: Input, output and indices must be on the current device
After reading up on related issues, I found a more convenient approach, using the following code:
import torch
import torch.nn as nn

if isinstance(model, torch.nn.DataParallel):
    model = model.module  # unwrap the existing DataParallel shell first
# re-wrapping the bare model then no longer raises
model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
After this the model loaded normally and could run on all four cards.
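A quick sanity check (my own addition, not part of the original workflow) that the unwrap/re-wrap worked:
# the re-wrapped model exposes the underlying Transformers model via .module,
# and its parameters should now live on device_ids[0]
print(type(model.module).__name__)      # BertForSequenceClassification
print(next(model.parameters()).device)  # cuda:0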
During the fine-tuning stage I had manually put the model into DataParallel mode before fine-tuning with the Trainer, and the checkpoints were saved automatically by the Trainer.
The TrainingArguments were set as follows:
training_args = TrainingArguments(
    output_dir='./results_0415_32',   # output directory
    num_train_epochs=3,               # total number of training epochs
    per_device_train_batch_size=128,  # batch size per device during training
    per_device_eval_batch_size=128,   # batch size for evaluation
    warmup_steps=500,                 # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,                # weight decay rate, analogous to the lambda in L2 regularization
    logging_dir='./logs',             # directory for storing logs
    logging_steps=10000,
)
I then started fine-tuning from a fresh BertForSequenceClassification model, manually putting it into DataParallel mode:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(afid2nor))
# nn.DataParallel expects the model's parameters on device_ids[0], i.e. cuda:0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
model.to(device)
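For completeness, the Trainer was then launched roughly as below. This is a sketch; train_dataset and eval_dataset stand in for my actual tokenized datasets:
from transformers import Trainer

trainer = Trainer(
    model=model,                 # the manually DataParallel-wrapped model
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()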
Printing the wrapped model gives:
DataParallel(
(module): BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
# layers 1 through 10 omitted; structurally identical to layer 0
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=25762, bias=True)
)
)
As you can see, the model printout shows two levels of wrapping at the top:
(module): BertForSequenceClassification(
(bert): BertModel(
The (module) level is added by DataParallel, and every parameter name in the saved checkpoint therefore carries an extra module. prefix. This is what breaks loading from the checkpoint later.
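You can see the prefix directly by inspecting the saved state dict; this quick check is my addition:
import torch

# every key carries the extra 'module.' prefix added by nn.DataParallel
state_dict = torch.load('./results_0415_32/checkpoint-98000/pytorch_model.bin', map_location='cpu')
print(list(state_dict.keys())[:2])
# e.g. ['module.bert.embeddings.position_ids', 'module.bert.embeddings.word_embeddings.weight']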
The loading code was:
model = AutoModelForSequenceClassification.from_pretrained('./results_0415_32/checkpoint-98000/')
It produced the following warning:
Some weights of the model checkpoint at ./results_0415_32/checkpoint-98000/ were not used when initializing BertForSequenceClassification: ['module.bert.embeddings.position_ids', 'module.bert.embeddings.word_embeddings.weight', 'module.bert.embeddings.position_embeddings.weight', 'module.bert.embeddings.token_type_embeddings.weight', 'module.bert.embeddings.LayerNorm.weight', 'module.bert.embeddings.LayerNorm.bias', 'module.bert.encoder.layer.0.attention.self.query.weight', 'module.bert.encoder.layer.0.attention.self.query.bias', 'module.bert.encoder.layer.0.attention.self.key.weight', 'module.bert.encoder.layer.0.attention.self.key.bias', 'module.bert.encoder.layer.0.attention.self.value.weight', 'module.bert.encoder.layer.0.attention.self.value.bias', 'module.bert.encoder.layer.0.attention.output.dense.weight', 'module.bert.encoder.layer.0.attention.output.dense.bias', 'module.bert.encoder.layer.0.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.0.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.0.intermediate.dense.weight', 'module.bert.encoder.layer.0.intermediate.dense.bias', 'module.bert.encoder.layer.0.output.dense.weight', 'module.bert.encoder.layer.0.output.dense.bias', 'module.bert.encoder.layer.0.output.LayerNorm.weight', 'module.bert.encoder.layer.0.output.LayerNorm.bias', 'module.bert.encoder.layer.1.attention.self.query.weight', 'module.bert.encoder.layer.1.attention.self.query.bias', 'module.bert.encoder.layer.1.attention.self.key.weight', 'module.bert.encoder.layer.1.attention.self.key.bias', 'module.bert.encoder.layer.1.attention.self.value.weight', 'module.bert.encoder.layer.1.attention.self.value.bias', 'module.bert.encoder.layer.1.attention.output.dense.weight', 'module.bert.encoder.layer.1.attention.output.dense.bias', 'module.bert.encoder.layer.1.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.1.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.1.intermediate.dense.weight', 'module.bert.encoder.layer.1.intermediate.dense.bias', 'module.bert.encoder.layer.1.output.dense.weight', 'module.bert.encoder.layer.1.output.dense.bias', 'module.bert.encoder.layer.1.output.LayerNorm.weight', 'module.bert.encoder.layer.1.output.LayerNorm.bias', 'module.bert.encoder.layer.2.attention.self.query.weight', 'module.bert.encoder.layer.2.attention.self.query.bias', 'module.bert.encoder.layer.2.attention.self.key.weight', 'module.bert.encoder.layer.2.attention.self.key.bias', 'module.bert.encoder.layer.2.attention.self.value.weight', 'module.bert.encoder.layer.2.attention.self.value.bias', 'module.bert.encoder.layer.2.attention.output.dense.weight', 'module.bert.encoder.layer.2.attention.output.dense.bias', 'module.bert.encoder.layer.2.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.2.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.2.intermediate.dense.weight', 'module.bert.encoder.layer.2.intermediate.dense.bias', 'module.bert.encoder.layer.2.output.dense.weight', 'module.bert.encoder.layer.2.output.dense.bias', 'module.bert.encoder.layer.2.output.LayerNorm.weight', 'module.bert.encoder.layer.2.output.LayerNorm.bias', 'module.bert.encoder.layer.3.attention.self.query.weight', 'module.bert.encoder.layer.3.attention.self.query.bias', 'module.bert.encoder.layer.3.attention.self.key.weight', 'module.bert.encoder.layer.3.attention.self.key.bias', 'module.bert.encoder.layer.3.attention.self.value.weight', 'module.bert.encoder.layer.3.attention.self.value.bias', 
'module.bert.encoder.layer.3.attention.output.dense.weight', 'module.bert.encoder.layer.3.attention.output.dense.bias', 'module.bert.encoder.layer.3.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.3.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.3.intermediate.dense.weight', 'module.bert.encoder.layer.3.intermediate.dense.bias', 'module.bert.encoder.layer.3.output.dense.weight', 'module.bert.encoder.layer.3.output.dense.bias', 'module.bert.encoder.layer.3.output.LayerNorm.weight', 'module.bert.encoder.layer.3.output.LayerNorm.bias', 'module.bert.encoder.layer.4.attention.self.query.weight', 'module.bert.encoder.layer.4.attention.self.query.bias', 'module.bert.encoder.layer.4.attention.self.key.weight', 'module.bert.encoder.layer.4.attention.self.key.bias', 'module.bert.encoder.layer.4.attention.self.value.weight', 'module.bert.encoder.layer.4.attention.self.value.bias', 'module.bert.encoder.layer.4.attention.output.dense.weight', 'module.bert.encoder.layer.4.attention.output.dense.bias', 'module.bert.encoder.layer.4.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.4.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.4.intermediate.dense.weight', 'module.bert.encoder.layer.4.intermediate.dense.bias', 'module.bert.encoder.layer.4.output.dense.weight', 'module.bert.encoder.layer.4.output.dense.bias', 'module.bert.encoder.layer.4.output.LayerNorm.weight', 'module.bert.encoder.layer.4.output.LayerNorm.bias', 'module.bert.encoder.layer.5.attention.self.query.weight', 'module.bert.encoder.layer.5.attention.self.query.bias', 'module.bert.encoder.layer.5.attention.self.key.weight', 'module.bert.encoder.layer.5.attention.self.key.bias', 'module.bert.encoder.layer.5.attention.self.value.weight', 'module.bert.encoder.layer.5.attention.self.value.bias', 'module.bert.encoder.layer.5.attention.output.dense.weight', 'module.bert.encoder.layer.5.attention.output.dense.bias', 'module.bert.encoder.layer.5.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.5.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.5.intermediate.dense.weight', 'module.bert.encoder.layer.5.intermediate.dense.bias', 'module.bert.encoder.layer.5.output.dense.weight', 'module.bert.encoder.layer.5.output.dense.bias', 'module.bert.encoder.layer.5.output.LayerNorm.weight', 'module.bert.encoder.layer.5.output.LayerNorm.bias', 'module.bert.encoder.layer.6.attention.self.query.weight', 'module.bert.encoder.layer.6.attention.self.query.bias', 'module.bert.encoder.layer.6.attention.self.key.weight', 'module.bert.encoder.layer.6.attention.self.key.bias', 'module.bert.encoder.layer.6.attention.self.value.weight', 'module.bert.encoder.layer.6.attention.self.value.bias', 'module.bert.encoder.layer.6.attention.output.dense.weight', 'module.bert.encoder.layer.6.attention.output.dense.bias', 'module.bert.encoder.layer.6.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.6.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.6.intermediate.dense.weight', 'module.bert.encoder.layer.6.intermediate.dense.bias', 'module.bert.encoder.layer.6.output.dense.weight', 'module.bert.encoder.layer.6.output.dense.bias', 'module.bert.encoder.layer.6.output.LayerNorm.weight', 'module.bert.encoder.layer.6.output.LayerNorm.bias', 'module.bert.encoder.layer.7.attention.self.query.weight', 'module.bert.encoder.layer.7.attention.self.query.bias', 'module.bert.encoder.layer.7.attention.self.key.weight', 'module.bert.encoder.layer.7.attention.self.key.bias', 
'module.bert.encoder.layer.7.attention.self.value.weight', 'module.bert.encoder.layer.7.attention.self.value.bias', 'module.bert.encoder.layer.7.attention.output.dense.weight', 'module.bert.encoder.layer.7.attention.output.dense.bias', 'module.bert.encoder.layer.7.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.7.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.7.intermediate.dense.weight', 'module.bert.encoder.layer.7.intermediate.dense.bias', 'module.bert.encoder.layer.7.output.dense.weight', 'module.bert.encoder.layer.7.output.dense.bias', 'module.bert.encoder.layer.7.output.LayerNorm.weight', 'module.bert.encoder.layer.7.output.LayerNorm.bias', 'module.bert.encoder.layer.8.attention.self.query.weight', 'module.bert.encoder.layer.8.attention.self.query.bias', 'module.bert.encoder.layer.8.attention.self.key.weight', 'module.bert.encoder.layer.8.attention.self.key.bias', 'module.bert.encoder.layer.8.attention.self.value.weight', 'module.bert.encoder.layer.8.attention.self.value.bias', 'module.bert.encoder.layer.8.attention.output.dense.weight', 'module.bert.encoder.layer.8.attention.output.dense.bias', 'module.bert.encoder.layer.8.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.8.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.8.intermediate.dense.weight', 'module.bert.encoder.layer.8.intermediate.dense.bias', 'module.bert.encoder.layer.8.output.dense.weight', 'module.bert.encoder.layer.8.output.dense.bias', 'module.bert.encoder.layer.8.output.LayerNorm.weight', 'module.bert.encoder.layer.8.output.LayerNorm.bias', 'module.bert.encoder.layer.9.attention.self.query.weight', 'module.bert.encoder.layer.9.attention.self.query.bias', 'module.bert.encoder.layer.9.attention.self.key.weight', 'module.bert.encoder.layer.9.attention.self.key.bias', 'module.bert.encoder.layer.9.attention.self.value.weight', 'module.bert.encoder.layer.9.attention.self.value.bias', 'module.bert.encoder.layer.9.attention.output.dense.weight', 'module.bert.encoder.layer.9.attention.output.dense.bias', 'module.bert.encoder.layer.9.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.9.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.9.intermediate.dense.weight', 'module.bert.encoder.layer.9.intermediate.dense.bias', 'module.bert.encoder.layer.9.output.dense.weight', 'module.bert.encoder.layer.9.output.dense.bias', 'module.bert.encoder.layer.9.output.LayerNorm.weight', 'module.bert.encoder.layer.9.output.LayerNorm.bias', 'module.bert.encoder.layer.10.attention.self.query.weight', 'module.bert.encoder.layer.10.attention.self.query.bias', 'module.bert.encoder.layer.10.attention.self.key.weight', 'module.bert.encoder.layer.10.attention.self.key.bias', 'module.bert.encoder.layer.10.attention.self.value.weight', 'module.bert.encoder.layer.10.attention.self.value.bias', 'module.bert.encoder.layer.10.attention.output.dense.weight', 'module.bert.encoder.layer.10.attention.output.dense.bias', 'module.bert.encoder.layer.10.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.10.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.10.intermediate.dense.weight', 'module.bert.encoder.layer.10.intermediate.dense.bias', 'module.bert.encoder.layer.10.output.dense.weight', 'module.bert.encoder.layer.10.output.dense.bias', 'module.bert.encoder.layer.10.output.LayerNorm.weight', 'module.bert.encoder.layer.10.output.LayerNorm.bias', 'module.bert.encoder.layer.11.attention.self.query.weight', 'module.bert.encoder.layer.11.attention.self.query.bias', 
'module.bert.encoder.layer.11.attention.self.key.weight', 'module.bert.encoder.layer.11.attention.self.key.bias', 'module.bert.encoder.layer.11.attention.self.value.weight', 'module.bert.encoder.layer.11.attention.self.value.bias', 'module.bert.encoder.layer.11.attention.output.dense.weight', 'module.bert.encoder.layer.11.attention.output.dense.bias', 'module.bert.encoder.layer.11.attention.output.LayerNorm.weight', 'module.bert.encoder.layer.11.attention.output.LayerNorm.bias', 'module.bert.encoder.layer.11.intermediate.dense.weight', 'module.bert.encoder.layer.11.intermediate.dense.bias', 'module.bert.encoder.layer.11.output.dense.weight', 'module.bert.encoder.layer.11.output.dense.bias', 'module.bert.encoder.layer.11.output.LayerNorm.weight', 'module.bert.encoder.layer.11.output.LayerNorm.bias', 'module.bert.pooler.dense.weight', 'module.bert.pooler.dense.bias', 'module.classifier.weight', 'module.classifier.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./results_0415_32/checkpoint-98000/ and are newly initialized: ['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 
'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 
'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
As the warning shows, a large number of parameters fail to match. The cause is that training used DataParallel, so every parameter name in the checkpoint carries an extra module. prefix that BertForSequenceClassification does not expect.
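Given this cause, the most direct fix is to strip the prefix from the keys before loading. This is a sketch of that idea rather than what I originally did; CLASS_NUM stands in for the number of labels:
import torch
from transformers import BertForSequenceClassification

state_dict = torch.load('./results_0415_32/checkpoint-98000/pytorch_model.bin', map_location='cpu')
# drop the leading 'module.' that nn.DataParallel added to every key
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in state_dict.items()}

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
# strict=False tolerates non-parameter buffers such as position_ids
model.load_state_dict(state_dict, strict=False)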
The workaround I actually verified at the time is more cumbersome and involves the following steps. The code:
# Step 1: load the raw state dict from the checkpoint (keys carry the 'module.' prefix)
pytorch_model_state_dict = torch.load('./results_0415_32/checkpoint-98000/pytorch_model.bin', map_location=torch.device('cpu'))
# Step 2: rebuild the same DataParallel-wrapped architecture
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # device_ids[0]
model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
model.to(device)
# Step 3: the checkpoint keys now match the wrapped model, so loading succeeds
model.load_state_dict(pytorch_model_state_dict)
# Step 4: move the loaded model to the target device
model.to(device)
# Step 5: save the inner module's state dict, whose keys carry no 'module.' prefix
torch.save(model.module.state_dict(), 'static_dict.pkl')
# Step 6: build a plain, unwrapped model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
model.to(device)
# Step 7: load the cleaned state dict into the plain model
model.load_state_dict(torch.load('static_dict.pkl'))
This save-and-reload dance is cumbersome. After discussing it with a classmate, I learned that the Trainer class already wraps the model in DataParallel by default when multiple GPUs are visible, so there is no need to wrap it manually; checkpoints saved that way have no nested module. prefix, and loading them causes no mismatch. I tested this and it works: every parameter matches, with no errors or warnings.
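In other words, the clean setup looks roughly like this (a sketch; training_args, train_dataset, and eval_dataset are as before):
from transformers import AutoModelForSequenceClassification, BertForSequenceClassification, Trainer

# no manual nn.DataParallel: Trainer wraps the model itself when it sees
# multiple GPUs, and saves checkpoints without the 'module.' prefix
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
trainer = Trainer(model=model, args=training_args,
                  train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train()

# the resulting checkpoint then loads cleanly:
model = AutoModelForSequenceClassification.from_pretrained('./results_0415_32/checkpoint-98000/')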
One caveat: the default DataParallel mode always spreads the model across every GPU on the machine. On a shared four-card machine, cuda 0 and cuda 2 may already be occupied by other users; the Trainer will nevertheless request memory on cuda 0, and because it is occupied, the program dies with a CUDA out of memory error.
This can be avoided by controlling which GPUs the program is allowed to see, set on the command line when launching the script. For example:
CUDA_VISIBLE_DEVICES=1,3 python my_script.py
lets the program see only cuda 1 and cuda 3. Internally they are renumbered as cuda 0 and cuda 1, so the program behaves as if the machine had just those two GPUs.
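The same restriction can also be applied inside the script, provided it runs before CUDA is initialized; this variant is my addition:
import os

# must run before torch initializes CUDA (ideally before importing torch)
os.environ["CUDA_VISIBLE_DEVICES"] = "1,3"

import torch
print(torch.cuda.device_count())  # 2: the visible GPUs are renumbered 0 and 1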