pytorch--多卡单卡模型加载

pytorch多卡单卡模型加载

  • 一、 ⇒ l o a d   t o   \xRightarrow{load \ to\ } load to  单卡
  • 二、 ⇒ l o a d   t o   \xRightarrow{load \ to\ } load to  多卡
  • 三、使用

\quad 模型的保存和加载参照 pytorch模型保存及加载详解
\quad 多卡保存的时候,在model的state_dict()参数多了一个"module."的前缀,其他的参数保存的时候单卡多卡保存并没有区别。因此在模型相互加载之前把这个处理好这个前缀就可以了。

一、 ⇒ l o a d   t o   \xRightarrow{load \ to\ } load to  单卡

def strip_prefix(self, state_dict, prefix='module.'):
    if not all(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict
    stripped_state_dict = {}
    for key in list(state_dict.keys()):
        stripped_state_dict[key.replace(prefix, '')] = state_dict.pop(key)
    return stripped_state_dict

二、 ⇒ l o a d   t o   \xRightarrow{load \ to\ } load to  多卡

def add_prefix(self, state_dict, prefix='module.'):
    if all(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict
    stripped_state_dict = {}
    for key in list(state_dict.keys()):
        key2 = prefix + key
        stripped_state_dict[key2] = state_dict.pop(key)
    return stripped_state_dict

三、使用

checkpoint = torch.load(pretrain)
if multi_gpu is not None:
	model.load_state_dict(self.add_prefix(checkpoint['state_dict']))
else:
    model.load_state_dict(self.strip_prefix(checkpoint['state_dict']))
optimizer.load_state_dict(checkpoint['optimizer'])

你可能感兴趣的:(pytorch,pytorch,python,深度学习)