pytorch模型保存及加载详解(官网理解翻译+补充)

目录

  • 1. 什么是 state_dict?
  • 2. 为了评估保存加载模型
    • 2.1 保存模型参数 state_dict(建议)
    • 2.2 保存整个模型(并不建议)
  • 3. 为了评估或再训练保存模型
  • 4. 将多个模型保存在一个文件里面
  • 5. 使用来自不同模型的参数进行热启动
  • 6. 在设备之间保存加载模型
    • 6.1 GPU上保存,CPU上加载
    • 6.2 GPU上保存,GPU上加载
    • 6.3 CPU上保存,GPU上加载
    • 6.4 模型在多个GPU并行

1. 什么是 state_dict?

  1. 在pytorch中,一个模型的可学习的参数(权重和偏置)被包含在模型的参数中,可以用model.parameters()查看。一个state_dict就是一个pytorch字典对象,将每个图层映射到他的参数张量上。注意:只有具有可学习参数的层(卷积层,线性层等)和已注册的缓冲区(batchnorm的running_mean)才在模型的state_dict中具有条目。优化器optimizer(torch.optim)也拥有一个state_dict,包含了优化器的状态信息,以及使用的超参数。
  2. 因为state_dict是一个字典。所以可以很方便地进行保存、更新、变更、恢复等等操作。从而为PyTorch模型和优化器增加了很多模块化。

举个栗子:

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

输出

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

2. 为了评估保存加载模型

2.1 保存模型参数 state_dict(建议)

保存:

torch.save(model.state_dict(), PATH)

加载:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
  1. 【注意】PyTorch的1.6版本将torch.save切换为使用新的基于zipfile的文件格式。torch.load()仍然保留着加载旧格式文件的功能。如果想用torch.saved的旧格式,参数里加一句_use_new_zipfile_serialization=False
  2. 为了评估而保存模型,我们就是为了保存那些可学习的参数。用torch.save()保存模型的state_dict给你后面恢复模型提供了最大的灵活度,所以建议这样做。
  3. 一个常见的模型保存文件后缀是.pt 或者是.pth
  4. 注意本小节的标题,评估。所以,在模型加载完成后使用的时候,需要一句model.eval()来固定BN和dropout。
  5. 【注意】load_state_dict()里面的参数是一个字典对象,而不是一个路径。举一个完整一点的例子

假如我保存的时候是这样的

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, PATH)

那么加载的时候就是这样的

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
  1. 【注意】如果仅打算保留性能最佳的模型(利用验证集),不要忘记best_model_state = model.state_dict()返回对状态的引用,而不是它的副本! 必须序列化best_model_state或使用best_model_state = deepcopy(model.state_dict()),否则这个所谓的最好的best_model_state将在后续训练迭代中不断更新。最终模型状态会是过拟合的。

2.2 保存整个模型(并不建议)

保存:

torch.save(model, PATH)

加载:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
  1. 这个保存/加载方法最直观,代码最少。这种模型保存方式将会用Python的pickle保存整个模块。当然他有缺点,这种方法将序列化的类绑定到特定的类,并且保存模型的时候要是用确切的目录结构(PATH的选取)。
  2. 这样做的原因是pickle不会保存这个模型类本身。而是保存一个路径,这个路径指向一个包含这个模型类的文件,这个路径在被加载的时候调用。所以问题就在这,当这个代码在其他项目中使用或重构后,可能会以各种方式中断。

3. 为了评估或再训练保存模型

保存:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
  1. 当保存一个checkpoint,不论是为了后面的评估还是为了再训练,单单保存模型的state_dict是不够的,还需要更多的信息才行。就比如优化器的state_dict,因为他因为它包含随着模型训练而更新的缓冲区和参数。其他还有保存这个state_dict时的epoch、loss等等。这样的一个checkpoint往往是单独一个模型的2~3倍。
  2. 想上面那样保存多个组件的时候,文件格式有点不一样,文件后缀是.tar
  3. 加载的时候就不多说了,就按上面那样使用就可以了。

4. 将多个模型保存在一个文件里面

保存:

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

加载:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
  1. 保存包含多个torch.nn.Modules的模型(例如GAN,序列到序列模型或模型集合)时,就按上面的方法就可以。
  2. 保存的文件后缀是.tar
  3. 加载前先初始化,然后直接用torch.load()加载字典。

5. 使用来自不同模型的参数进行热启动

保存:

torch.save(modelA.state_dict(), PATH)

加载:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
  1. 在迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是很常见的方案。 利用经过训练的参数,即使只有少数几个可用的参数,也将有助于热启动过程,这与从头开始训练相比,可以更快地收敛模型。
  2. 不论是你想从参数比你少的模型加载,还是从参数比你多的模型加载,让strict这个参数为False就可以了。
  3. 如果你想将一层中的参数加载到另一层,但是有一些关键字不匹配的话,那么就改变要加载的state_dict中的参数关键字来使得关键字匹配。

6. 在设备之间保存加载模型

6.1 GPU上保存,CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
  1. 当在cpu上加载一个gpu上训练的模型的时候,把torch.device(‘cpu’)传给torch.load中map_location参数就可以了。这样的就不用手动将gpu上的tensor改成cpu上的了。map_location参数将张量的存储重映射到cpu设备上。

6.2 GPU上保存,GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors 
# that you feed to the model
  1. 模型加载完后不要忘记放到gpu上面。model.to(torch.device(“cuda”))或者是model.cuda()都可以的。还有就是,gpu上的模型训练的时候要求输入也必须是在gpu上的,所以,后面使用这个模型的时候,要把数据放到gpu上面。input=input.cuda()和input=input.to(torch.device(‘cuda’))都可以的。
  2. 优化器optimizer使用下面的语句放到gpu上面
for state in optimizer.state.values():
    for k, v in state.items():
        if torch.is_tensor(v):
          state[k] = v.cuda()
  1. 【小贴士】模型放到gpu上的时候,没有使用返回值。但是数据不会重写,我们要input=返回值才行。

6.3 CPU上保存,GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
  1. 同样的,由于模型在不同设备上的原因,我们还是需要 map_location这个参数,作用上面已经讲过。这里是将张量重映射到gpu上。

6.4 模型在多个GPU并行

保存:

torch.save(model.module.state_dict(), PATH)

加载:

Load to whatever device you want
  1. torch.nn.DataParallel是一个模型封装程序,可实现并行GPU利用率。 使用model.module.state_dict()来保存一个DataParallel模型。 然后就可以将所需的模型加载到所需的任何设备。
  2. 【小贴士】保存optimizer的时候还是和正常一样的操作,不用加一个module。

你可能感兴趣的:(pytorch,python,pytorch)