PyTorch nn.DataParallel model state_dict: converting a multi-GPU model for single-GPU use (or any other GPU count)

Tags: pytorch, nn.DataParallel, model.state_dict


Reference: reference link

Problem

When training models with PyTorch, we may have several groups of servers, each with a different GPU model and GPU count. A checkpoint saved through nn.DataParallel is tied to how the model was wrapped during training, but in practice we need to move models freely to servers with a different number of GPUs for testing. The code below addresses exactly this problem.
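To see the problem concretely: wrapping a model in nn.DataParallel nests it under the attribute `module`, so every key in its state_dict gains a `module.` prefix, and such a checkpoint will not load into a plain, unwrapped model. A minimal sketch (this runs even on a CPU-only machine, where DataParallel simply delegates to the wrapped module):

```python
import torch.nn as nn

net = nn.Linear(4, 2)
print(list(net.state_dict().keys()))      # ['weight', 'bias']

# DataParallel stores the original model under the attribute 'module',
# so every state_dict key gains a 'module.' prefix.
wrapped = nn.DataParallel(net)
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
```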

Solution

Reimplement nn.Module's state_dict and load_state_dict as standalone functions. In other words, the model we save and load is the one *not* wrapped by nn.DataParallel, so it can be loaded and trained on any number of GPUs.
If you have already trained with multiple GPUs, just copy the code below and run one more epoch (saving with the new state_dict function) to get a portable checkpoint.

import torch
import torch.nn as nn
from collections import OrderedDict
from torch.nn.parameter import Parameter

def state_dict(model, destination=None, prefix='', keep_vars=False):
    # Unwrap nn.DataParallel so the saved keys carry no 'module.' prefix.
    own_state = model.module if isinstance(model, torch.nn.DataParallel) \
        else model
    if destination is None:
        destination = OrderedDict()
    # Collect this module's parameters (weights, biases).
    for name, param in own_state._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.data
    # Collect this module's buffers (e.g. BatchNorm running statistics).
    for name, buf in own_state._buffers.items():
        if buf is not None:
            destination[prefix + name] = buf
    # Recurse into child modules, extending the key prefix.
    for name, module in own_state._modules.items():
        if module is not None:
            state_dict(module, destination, prefix + name + '.',
                       keep_vars=keep_vars)
    return destination

def load_state_dict(model, state_dict, strict=True):
    # Unwrap nn.DataParallel so the checkpoint keys match the plain model.
    own_state = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) \
        else model.state_dict()
    for name, param in state_dict.items():
        if name in own_state:
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                raise RuntimeError('While copying the parameter named {}, '
                                   'whose dimensions in the model are {} and '
                                   'whose dimensions in the checkpoint are {}.'
                                   .format(name, own_state[name].size(),
                                           param.size()))
        elif strict:
            raise KeyError('unexpected key "{}" in state_dict'.format(name))
    if strict:
        # Every key the model expects must be present in the checkpoint.
        missing = set(own_state.keys()) - set(state_dict.keys())
        if len(missing) > 0:
            raise KeyError('missing keys in state_dict: "{}"'.format(missing))
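The same effect can also be obtained with PyTorch's built-in API by saving the state_dict of the *unwrapped* model (`model.module` when wrapped). A quick round-trip sanity check of the idea (a sketch; it runs on a CPU-only machine, where DataParallel falls back to the plain module):

```python
import torch
import torch.nn as nn

# Wrap a small model in DataParallel.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
wrapped = nn.DataParallel(net)

# Saving the *underlying* module's state_dict yields keys without the
# 'module.' prefix, i.e. a checkpoint independent of the GPU count.
clean_sd = wrapped.module.state_dict()

# A fresh, unwrapped copy of the same architecture loads it directly.
fresh = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
fresh.load_state_dict(clean_sd)
```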

### Usage

# Before: we used to get the model's state_dict like this:
model.state_dict()

# Now: call the standalone function instead:
state_dict(model)

# Before: loading was done through the method on the model:
model.load_state_dict(your_state_dict)

# Now: it looks like this:
your_state_dict = state_dict(model)
load_state_dict(model, your_state_dict)
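If you already have a checkpoint that was saved directly from a nn.DataParallel model (keys starting with `module.`), you can also repair it without rerunning any training by stripping the prefix from each key. A minimal sketch of that common workaround:

```python
from collections import OrderedDict

import torch.nn as nn

# Simulate a checkpoint saved directly from a DataParallel-wrapped model:
# its keys look like 'module.weight', 'module.bias'.
multi_gpu_sd = nn.DataParallel(nn.Linear(4, 2)).state_dict()

# Strip the 'module.' prefix so a plain, unwrapped model can load it.
single_gpu_sd = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in multi_gpu_sd.items()
)

fresh = nn.Linear(4, 2)
fresh.load_state_dict(single_gpu_sd)
```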

Hope your models become unstoppable from now on (DZT)
