Pytorch加载模型后optimizer.step()报RuntimeError: output with shape...错误

错误背景

存储模型参数后,重新加载接着训练,结果optimizer.step()报如下错误:

...
RuntimeError: output with shape...

例如:

model = NLPModel() # 初始化Model
# model中包含BERT,训练时不修改BERT参数
params = list(set(model.parameters()) - set(model.bert.parameters()))  # 造成错误根本原因
optimizer = torch.optim.Adam(param)

... # 训练代码
optimizer.step() # 没什么问题
... 

# 终止训练,存一下训练状态
torch.save({
    'model': self.model.state_dict(),
    'optimizer': self.optimizer.state_dict(),
}, checkpoint_path)

当下次开始接着上次的训练:

# 加载模型
checkpoint = torch.load(checkpoint_path)
# 加载模型参数
model.load_state_dict(checkpoint['model'])
# 加载optimizer参数
optimizer.load_state_dict(checkpoint['optimizer'])

... # 开始训练
optimizer.step() # 报错
... 

结果在optimizer.step()步骤报错。

错误原因

因为在构建optimizer时对模型参数使用了set()进行包装,

params = list(set(model.parameters()) - set(model.bert.parameters()))  # 造成错误根本原因
optimizer = torch.optim.Adam(param)

set是无序的。这就导致两次的模型参数顺序不一致。进而导致报错

修改方案

不要使用set对参数包装,换一种方式,例如,修改为:

# params = list(set(model.parameters()) - set(model.bert.parameters()))  # 不能这么写
params = []
for key, value in self.named_parameters():
    if not key.startswith("bert."):
        params.append(value)
optimizer = torch.optim.Adam(param)

参考资料

https://github.com/InterDigitalInc/CompressAI/issues/34

你可能感兴趣的:(机器学习,pytorch,深度学习,python)