import torch
from torch import nn

class BoringModel(nn.Sequential):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 10)
        self.stages = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10)
        )
        self.out_proj = nn.Linear(10, 2)
model = BoringModel()
# 1x memory used, where x is the size of the model
# model is now in memory
torch.save(model.state_dict(), "./checkpoint.pt")
# our model is now stored on disk
# we need to redefine the model
model = BoringModel()
# 1x memory used
state_dict = torch.load("./checkpoint.pt")
# 2x memory used -> both model and state_dict are in memory!!!
model.load_state_dict(state_dict)
# 1x memory used
We need twice the memory to load the weights we stored earlier. This is a problem for huge models, because we need twice as much free RAM. For example, say we have 16GB of RAM and our model uses 10GB: loading it would require 20GB, so we need to change our strategy.
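If you want to watch the 2x peak yourself, you can track the process' resident memory around torch.load. A minimal sketch, assuming psutil is installed (the rss_gb helper is my addition, not part of the original walkthrough):

import os
import psutil  # assumption: psutil is available in the environment

def rss_gb() -> float:
    # resident set size of the current process, in GB
    return psutil.Process(os.getpid()).memory_info().rss / 1e9

model = BoringModel()
print(f"after model: {rss_gb():.2f} GB")  # ~1x
state_dict = torch.load("./checkpoint.pt")
print(f"after load:  {rss_gb():.2f} GB")  # ~2x: model and state_dict coexist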
Recently, PyTorch introduced the meta device. When you put a tensor on the meta device, only its metadata (e.g. its shape) is stored, and its values are tossed away. Thus, no space is used.
x = torch.tensor([1])
x
tensor([1])
x.to(torch.device("meta"))
tensor(..., device='meta', size=(1,), dtype=torch.int64)
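Since no storage is ever allocated, we can "create" tensors on the meta device that would never fit in RAM. A quick sanity check (the sizes here are purely for illustration):

# a 100_000 x 100_000 float32 tensor would need ~40GB if it were real
huge = torch.empty(100_000, 100_000, device="meta")
print(huge.shape)                             # torch.Size([100000, 100000])
print(huge.element_size() * huge.nelement())  # 40000000000 "virtual" bytes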
So, with this approach we can load our model while never using more than 1x memory:

1. define our model (1x memory)
2. put it on the meta device (1x memory)
3. load the state_dict (1x memory)
4. replace all empty parameters of our model with the values inside the state_dict (1x memory)
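As a side note, recent PyTorch versions even let us skip step 1's real allocation: torch.device works as a context manager, so the model can be created directly on the meta device. This is a sketch of that variant, not the path we follow below:

# assumption: PyTorch >= 2.0, where torch.device acts as a context manager
with torch.device("meta"):
    meta_model = BoringModel()  # parameters are created on meta, no RAM is used
print(next(meta_model.parameters()).device)  # meta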
Let's create the load_state_dict_with_low_memory function.
from typing import Dict

def load_state_dict_with_low_memory(model: nn.Module, state_dict: Dict[str, torch.Tensor]):
    # free up memory by placing the model on the `meta` device
    model.to(torch.device("meta"))
    # we need to associate each key in `state_dict` to a submodule
    # then, iteratively, re-create all submodules' parameters with the values in `state_dict`
    pass
load_state_dict_with_low_memory(model, {})
model.state_dict()
OrderedDict([('in_proj.weight', tensor(..., device='meta', size=(10, 2))),
('in_proj.bias', tensor(..., device='meta', size=(10,))),
('stages.0.weight', tensor(..., device='meta', size=(10, 10))),
('stages.0.bias', tensor(..., device='meta', size=(10,))),
('stages.1.weight', tensor(..., device='meta', size=(10, 10))),
('stages.1.bias', tensor(..., device='meta', size=(10,))),
('out_proj.weight', tensor(..., device='meta', size=(2, 10))),
('out_proj.bias', tensor(..., device='meta', size=(2,)))])
The model is now empty.
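Interestingly, an "empty" meta model can still run a forward pass on meta inputs; only shapes are propagated, which makes for a cheap architecture sanity check (my aside, not part of the original walkthrough):

# shape-only "dry run": no real computation or memory is involved
x = torch.rand(4, 2, device="meta")
print(model(x).shape)  # torch.Size([4, 2])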
Now we have to figure out in which submodule of model each parameter from state_dict has to go. One way to do it is to create a dictionary mapping [key_in_state_dict] -> [submodule_in_model].
This way we know where we have to place the values from the loaded state_dict. (Remember, as soon as the model is placed on the meta device, all its weights are tossed away.)
def get_keys_to_submodule(model: nn.Module) -> Dict[str, nn.Module]:
    keys_to_submodule = {}
    # iterate over all submodules
    for submodule_name, submodule in model.named_modules():
        # iterate over all parameters in each submodule
        for param_name, param in submodule.named_parameters():
            # param_name is organized as <subname>.<subname>...<name>
            # the deeper we go in the model, the fewer "subname"s we have
            splitted_param_name = param_name.split('.')
            # if we have only one subname, it means we reached a "leaf" submodule:
            # we cannot go inside it anymore. This is the actual parameter
            is_leaf_param = len(splitted_param_name) == 1
            if is_leaf_param:
                # we recreate the correct key
                key = f"{submodule_name}.{param_name}"
                # we associate this key with this submodule
                keys_to_submodule[key] = submodule
    return keys_to_submodule
get_keys_to_submodule(model)
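For our BoringModel, the returned mapping should look roughly like this (each state_dict key points to the Linear submodule that owns that parameter):

{'in_proj.weight': Linear(in_features=2, out_features=10, bias=True),
 'in_proj.bias': Linear(in_features=2, out_features=10, bias=True),
 'stages.0.weight': Linear(in_features=10, out_features=10, bias=True),
 'stages.0.bias': Linear(in_features=10, out_features=10, bias=True),
 'stages.1.weight': Linear(in_features=10, out_features=10, bias=True),
 'stages.1.bias': Linear(in_features=10, out_features=10, bias=True),
 'out_proj.weight': Linear(in_features=10, out_features=2, bias=True),
 'out_proj.bias': Linear(in_features=10, out_features=2, bias=True)}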
Now we have a way to know which key corresponds to which submodule of model. Let's go back to our load_state_dict_with_low_memory function and materialize each submodule's parameters with the correct values from state_dict.
def load_state_dict_with_low_memory(model: nn.Module, state_dict: Dict[str, torch.Tensor]):
    # free up memory by placing the model on the `meta` device
    model.to(torch.device("meta"))
    keys_to_submodule = get_keys_to_submodule(model)
    for key, submodule in keys_to_submodule.items():
        # get the value from the state_dict
        val = state_dict[key]
        # we need to substitute the parameter inside submodule,
        # remember key is composed of <subname(s)>.<name>;
        # the actual submodule's parameter is stored under the
        # last subname. If key is `in_proj.weight`, the correct field is `weight`
        param_name = key.split('.')[-1]
        param_dtype = getattr(submodule, param_name).dtype
        val = val.to(param_dtype)
        # create a new parameter
        new_val = torch.nn.Parameter(val, requires_grad=False)
        setattr(submodule, param_name, new_val)
model.state_dict()  # all parameters are still empty, on the meta device
load_state_dict_with_low_memory(model, torch.load("checkpoint.pt"))
model.state_dict()  # parameters now hold the values from the checkpoint
We have successfully loaded our checkpoint inside our model with linear memory consumption!
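One last pointer: newer PyTorch releases fold this pattern into the core API. To my knowledge (this is an aside, not part of the original recipe), torch.load(..., mmap=True) combined with load_state_dict(..., assign=True) achieves the same 1x-memory load:

# assumption: PyTorch >= 2.1 (the mmap= and assign= keywords)
with torch.device("meta"):
    model = BoringModel()
state_dict = torch.load("./checkpoint.pt", mmap=True)  # checkpoint stays on disk
model.load_state_dict(state_dict, assign=True)         # meta params are replaced, not copied into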