Notes on a problem with saving model parameters under DataParallel in the transformers library

Notes on a problem with saving model parameters under DataParallel in PyTorch

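I previously fine-tuned a BERT model from the Transformers library on my own text classification task, using the library's Trainer. Today I tried to load the saved checkpoint back into my program to run evaluation directly, and the parameters did not match. This post records the cause and how I solved it.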


Update

A new problem

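On a machine with only two GPUs (0 and 1), I trained a model several times and saved the whole model to a single file. Later I loaded this model on another machine with four GPUs (0, 1, 2, 3) to continue training, and found that it still used only GPUs 0 and 1. Simply calling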

model = nn.DataParallel(model, device_ids=[0,1,2,3])

raises the error

RuntimeError: Input, output and indices must be on the current device

After looking into similar issues, I found a more convenient approach: unwrap the model first, then wrap it again:

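import torch.nn as nn

# A model loaded from a whole-model file may still be wrapped in DataParallel
if isinstance(model, nn.DataParallel):
    model = model.module  # take out the plain model inside the wrapper
# Wrapping it again now works without errors
model = nn.DataParallel(model, device_ids=[0,1,2,3])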

After this, the model loads and runs normally, and it can also be spread across all four GPUs.
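Here is a minimal sketch of the same unwrap-then-rewrap pattern. It assumes the whole model object (not just its state_dict) was pickled with torch.save(model, ...).

The loaded object is itself a DataParallel that is still bound to the old device list. The error most likely came from wrapping one DataParallel inside another.

Some details of the sketch:
- A tiny nn.Linear stands in for BERT.
- The helper name rewrap is hypothetical.
- weights_only=False is needed on recent PyTorch versions to unpickle a full module, so only use it on files you trust.
- On a CPU-only machine, DataParallel simply holds the module without replicating it, so the sketch also runs there.

```python
import io

import torch
import torch.nn as nn


def rewrap(model, device_ids=None):
    """Remove an existing DataParallel wrapper (if any), then wrap again.

    device_ids=None means "all visible GPUs"; on a CPU-only machine
    DataParallel just holds the module without replicating it.
    """
    if isinstance(model, nn.DataParallel):
        model = model.module  # the plain model stored inside the wrapper
    return nn.DataParallel(model, device_ids=device_ids)


# Simulate a whole model that was saved while wrapped in DataParallel.
buffer = io.BytesIO()
torch.save(nn.DataParallel(nn.Linear(4, 2)), buffer)
buffer.seek(0)
loaded = torch.load(buffer, weights_only=False)  # full-object pickle, not just weights

model = rewrap(loaded)
print(type(model.module).__name__)  # Linear, not DataParallel
```

Because rewrap checks the type first, it is also safe to call on a model that was never wrapped.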


Loading directly from the checkpoint directory with AutoModelForSequenceClassification

Saving and using the model in the fine-tuning stage

During fine-tuning, I manually wrapped the model in DataParallel and then fine-tuned it with Trainer. The checkpoints were saved automatically by Trainer.

The TrainingArguments I used:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results_0415_32',         # output directory
    num_train_epochs=3,             # total number of training epochs
    per_device_train_batch_size=128, # batch size per device during training
    per_device_eval_batch_size=128,  # batch size per device for evaluation
    warmup_steps=500,               # number of warmup steps for learning rate scheduler
    weight_decay=0.01,              # strength of weight decay; plays a role similar to λ in L2 regularization
    logging_dir='./logs',           # directory for storing logs
    logging_steps=10000,
)

I then started fine-tuning from a fresh BertForSequenceClassification model and manually wrapped it in DataParallel:

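import torch
import torch.nn as nn
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(afid2nor))  # afid2nor: my own label mapping
# DataParallel expects the parameters on device_ids[0], i.e. cuda:0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model, device_ids=[0,1,2,3])
model.to(device)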

Printing the model now gives:

DataParallel(
  (module): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
            )
            (output): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )

# layers 1-10 omitted; they are essentially identical

          (11): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
            )
            (output): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (classifier): Linear(in_features=768, out_features=25762, bias=True)
  )
)

As the printout shows, the actual BERT model is now wrapped two levels deep:

  (module): BertForSequenceClassification(
    (bert): BertModel(

This nesting is what later breaks loading from the checkpoint.
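To see where the extra level comes from, here is a minimal check with a toy model. TinyClassifier is a hypothetical stand-in for BertForSequenceClassification that only mimics its two top-level attribute names, bert and classifier.

nn.DataParallel keeps the wrapped network in its module attribute. As a result, every key of its state_dict() gains a "module." prefix. The check should also run on CPU, where DataParallel only wraps the module.

```python
import torch.nn as nn


# Hypothetical stand-in: same top-level attribute names as BertForSequenceClassification.
class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(8, 8)
        self.classifier = nn.Linear(8, 3)

    def forward(self, x):
        return self.classifier(self.bert(x))


plain = TinyClassifier()
wrapped = nn.DataParallel(plain)  # all visible GPUs; on CPU-only it is a thin wrapper

plain_keys = list(plain.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())
print(plain_keys)    # ['bert.weight', 'bert.bias', 'classifier.weight', 'classifier.bias']
print(wrapped_keys)  # ['module.bert.weight', 'module.bert.bias', 'module.classifier.weight', 'module.classifier.bias']
```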

Loading the model

The loading code:

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained('./results_0415_32/checkpoint-98000/')

This printed the following warning:

Some weights of the model checkpoint at ./results_0415_32/checkpoint-98000/ were not used when initializing BertForSequenceClassification: ['module.bert.embeddings.position_ids', 'module.bert.embeddings.word_embeddings.weight', 'module.bert.embeddings.position_embeddings.weight', 'module.bert.embeddings.token_type_embeddings.weight', 'module.bert.embeddings.LayerNorm.weight', 'module.bert.embeddings.LayerNorm.bias', 'module.bert.encoder.layer.0.attention.self.query.weight', 'module.bert.encoder.layer.0.attention.self.query.bias', ... (remaining module.bert.encoder.layer.0 to layer.11 parameters omitted; same pattern) ..., 'module.bert.pooler.dense.weight', 'module.bert.pooler.dense.bias', 'module.classifier.weight', 'module.classifier.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./results_0415_32/checkpoint-98000/ and are newly initialized: ['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', ... (remaining encoder.layer.0 to layer.11 parameters omitted; same pattern) ..., 'pooler.dense.weight', 'pooler.dense.bias', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

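As you can see, a large number of parameters could not be matched. The cause is that training used DataParallel.

DataParallel stores the wrapped model in its "module" attribute, so every key in the saved checkpoint starts with "module." (here "module.bert." or "module.classifier."). None of these keys line up with the names a plain BertForSequenceClassification expects. The fine-tuned weights were therefore ignored, and the model was freshly initialized instead.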
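Since the mismatch is purely in the key names, another option is to strip the prefix before calling load_state_dict. This is not the method used in this post, which follows below.

Here is a minimal sketch, written against a plain OrderedDict so it runs without PyTorch. The helper name strip_prefix and the sample keys/values are hypothetical. A real state_dict is an OrderedDict of tensors and could be passed in the same way.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Keys without the prefix are kept unchanged, so calling this on a checkpoint
    that was not saved from a DataParallel model is harmless.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


checkpoint = OrderedDict([
    ("module.bert.embeddings.word_embeddings.weight", "w0"),
    ("module.classifier.weight", "w1"),
    ("module.classifier.bias", "b1"),
])
print(list(strip_prefix(checkpoint)))
# ['bert.embeddings.word_embeddings.weight', 'classifier.weight', 'classifier.bias']
```

With PyTorch, the intended use would be along the lines of model.load_state_dict(strip_prefix(torch.load(path, map_location='cpu'))) on an unwrapped BertForSequenceClassification. This assumes the same transformers version as at training time, so that the key sets match.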

Solution

Reloading the parameters and saving them again

One approach that I verified works, though it is rather tedious, consists of these steps:

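  1. Load the pytorch_model.bin file from the checkpoint directory. It holds the model parameters as saved under DataParallel.
  2. Build a model with the same architecture as the original fine-tuned model, and also wrap it in DataParallel on the devices.
  3. Load the parameters with model.load_state_dict.
  4. Move the model onto a single GPU. Note that .to() does not remove the DataParallel wrapper; the plain model is taken out via model.module in step 5.
  5. Save the parameters of model.module (which have no "module." prefix) to a local file.
  6. Build a new model with the same architecture as the original fine-tuned model, and put it on a single GPU.
  7. Load the parameters saved in step 5.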

The corresponding code:

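import torch
import torch.nn as nn
from transformers import BertForSequenceClassification

# CLASS_NUM: the number of labels used during fine-tuning (defined elsewhere in my script)

# Step 1: load the DataParallel state dict (every key starts with "module.") onto the CPU
pytorch_model_state_dict = torch.load('./results_0415_32/checkpoint-98000/pytorch_model.bin', map_location=torch.device('cpu'))

# Step 2: same architecture, wrapped in DataParallel so the key names match
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
# DataParallel expects the parameters on device_ids[0], i.e. cuda:0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model, device_ids=[0,1,2,3])
model.to(device)

# Step 3: the "module."-prefixed keys now line up with the wrapped model
model.load_state_dict(pytorch_model_state_dict)

# Step 4: move to a single GPU (the DataParallel wrapper itself is still there)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)

# Step 5: model.module is the plain BertForSequenceClassification; its keys have no "module." prefix
torch.save(model.module.state_dict(), 'static_dict.pkl')

# Step 6: a fresh, unwrapped model on a single GPU
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=CLASS_NUM)
model.to(device)

# Step 7: load the parameters saved in Step 5
model.load_state_dict(torch.load('static_dict.pkl', map_location=device))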
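To check that the steps above really recover the trained weights, here is a small round trip with toy stand-ins. It assumes a tiny nn.Linear in place of BERT and an in-memory buffer in place of pytorch_model.bin.

Path A follows this post's steps (load through a wrapper, then keep .module's weights). Path B is an alternative that is not from this post: it renames the keys and loads them directly. Both should end up with the same weights.

```python
import io

import torch
import torch.nn as nn

# Toy stand-in for the fine-tuned model, saved while wrapped in DataParallel.
trained = nn.DataParallel(nn.Linear(4, 2))
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)  # keys: "module.weight", "module.bias"
buffer.seek(0)
state = torch.load(buffer, map_location="cpu")  # Step 1

# Path A (Steps 2-7 above): load through a wrapper, then keep .module's weights.
wrapper = nn.DataParallel(nn.Linear(4, 2))
wrapper.load_state_dict(state)
model_a = nn.Linear(4, 2)
model_a.load_state_dict(wrapper.module.state_dict())

# Path B (alternative, not from this post): strip "module." and load directly.
renamed = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state.items()}
model_b = nn.Linear(4, 2)
model_b.load_state_dict(renamed)

print(sorted(state), sorted(renamed))  # ['module.bias', 'module.weight'] ['bias', 'weight']
```

Path B skips building a throwaway DataParallel model and the intermediate file. Path A has the advantage of never touching key names by hand.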

Don't wrap the model in DataParallel manually when using Trainer

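Loading and saving this way is very tedious. After discussing it with a classmate, I learned that Trainer already uses DataParallel by default when more than one GPU is visible, so there is no need to wrap the model manually. Without the manual wrapper, the saved parameters carry no extra nesting, and loading no longer produces mismatched parameter names.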

Testing confirmed that this works: all parameters match, with no errors or warnings.

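However, the default DataParallel mode always spreads the model across every GPU on the machine. In practice, several users may share a four-GPU machine, so cuda 0 and cuda 2 might already be in use. Trainer would still request memory on cuda 0 at startup. Because cuda 0 is occupied, the program would abort with a CUDA out of memory error.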

Choosing which GPUs to use

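We can solve this problem by controlling which GPUs the program can see. To do so, set the visible GPUs on the command line when launching the script. For example: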

CUDA_VISIBLE_DEVICES=1,3 python my_script.py

makes the program see only cuda 1 and cuda 3. Inside the program these are numbered cuda 0 and cuda 1, so the program believes the machine has only two GPUs.
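The renumbering can be made explicit with a small helper. The function name physical_gpu_id is hypothetical. It assumes CUDA_VISIBLE_DEVICES holds comma-separated integer ids; the variable may also contain GPU UUIDs, which this sketch does not handle. It maps the cuda index seen inside the program back to the physical GPU id.

```python
import os


def physical_gpu_id(logical_index, visible=None):
    """Map cuda:<logical_index> inside the program to the physical GPU id.

    Assumes CUDA_VISIBLE_DEVICES is a comma-separated list of integer ids
    (UUID entries are not handled in this sketch).
    """
    if visible is None:
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return logical_index  # variable unset: all GPUs visible, numbering unchanged
    ids = [int(part) for part in visible.split(",") if part.strip()]
    return ids[logical_index]


print(physical_gpu_id(0, "1,3"), physical_gpu_id(1, "1,3"))  # 1 3
```

You can also restrict the GPUs from inside the script instead of on the command line. To do this, set os.environ["CUDA_VISIBLE_DEVICES"] = "1,3" at the very top, before PyTorch initializes CUDA. Setting it after CUDA has been initialized has no effect.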

