如何保存MindSpore模型

本文介绍三种MindSpore权重模型保存方法

模型权重直接保存

MindSpore的权重文件称为checkpoint,以ckpt为文件后缀,使用MindSpore提供的统一接save_checkpoint对模型进行保存,代码中model为事先训练好的模型。

import mindspore

ckpt_file_path = './output/best_model_19.ckpt' # 保存路径

# 模型额外相关信息保存,如epoch, batch_size, 
append_info = dict()
append_info['batch_size'] = batch_size
append_info['version'] = mindspore.__version__
append_info['epoch'] = epoch

mindspore.save_checkpoint(model,
                          ckpt_file_path,
                          append_dict=append_info)

在训练时进行保存

MindSpore可以指定在训练结束后自动保存模型文件,需要在callbacks里指定保存的路径:
CheckPoint配置策略:
MindSpore有两种保存CheckPoint策略:迭代策略和时间策略,可以通过创建CheckpointConfig对象设置相应策略。 CheckpointConfig中有四个参数可以自定义设置:

  1. save_checkpoint_steps:表示每隔多少个step保存一个CheckPoint文件,默认值为1。
  2. save_checkpoint_seconds:表示每隔多少秒保存一个CheckPoint文件,默认值为0。
  3. keep_checkpoint_max:表示最多保存多少个CheckPoint文件,默认值为5。
  4. keep_checkpoint_per_n_minutes:表示每隔多少分钟保留一个CheckPoint文件,默认值为0。

用法1:

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config_ck = CheckpointConfig(save_checkpoint_steps=5,
                             keep_checkpoint_max=10)
 
ckpoint_cb = ModelCheckpoint(prefix='resnet50',
                             directory=save_path,
                             config=config_ck)
model.train(epoch_num,
            dataset,
            callbacks=ckpoint_cb)

用法2:

kpt_file_path = './output/best_model_19.ckpt' # 保存路径

# 在训练时指定保存路径,训练结束后会对该模型文件进行保存
model.train(epochs,
            dataset_train,
            callbacks=[ValAccMonitor(
                model,
                dataset_val,
                epochs,
                ckpt_directory=ckpt_file_path)])

保存最优模型

此方法可以保证保存的模型权重为在验证集中表现最优的权重:

## 定义单步训练
train_one_step = nn.TrainOneStepCell(net_with_loss, optimizer)

for epoch in range(epochs):
    train_one_epoch(train_one_step, imdb_train, epoch)
    valid_loss = evaluate(net, imdb_valid, loss, epoch)
    
	# 判断在验证集的损失对比目前最小损失值有没有更小
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # 最终保证
        save_checkpoint(net, ckpt_file_name)

你可能感兴趣的:(MindSpore教程大全,python,深度学习)