使用MindSpore训练及保存模型

MindSpore提供了回调Callback机制,可以在训练过程中执行自定义逻辑,这里以使用框架提供的ModelCheckpoint为例。

  1. ModelCheckpoint可以保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。

    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    # 设置模型保存参数

    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)

    # 应用模型保存参数

    ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)

  2. 通过MindSpore提供的model.train接口可以方便地进行网络的训练,LossMonitor可以监控训练过程中loss值的变化。

    # 导入模型训练需要的库

    from mindspore.nn import Accuracy

    from mindspore.train.callback import LossMonitor

    from mindspore import Model

  3. def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):

        """定义训练的方法"""

        # 加载训练数据集

        ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)

        model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)

  4. 其中,dataset_sink_mode用于控制数据是否下沉,数据下沉是指数据通过通道直接传送到Device上,可以加快训练速度,dataset_sink_mode为True表示数据下沉,否则为非下沉。

    通过模型运行测试数据集得到的结果,验证模型的泛化能力。

    使用model.eval接口读入测试数据集。

    使用保存后的模型参数进行推理。

  5. def test_net(network, model, data_path):

        """定义验证的方法"""

        ds_eval = create_dataset(os.path.join(data_path, "test"))

        acc = model.eval(ds_eval, dataset_sink_mode=False)

        print("{}".format(acc))

  6. 这里把train_epoch设置为1,对数据集进行1个迭代的训练。在train_net和 test_net方法中,我们加载了之前下载的训练数据集,mnist_path是MNIST数据集路径。

    train_epoch = 1

    mnist_path = "./datasets/MNIST_Data"

    dataset_size = 1

    model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)

    test_net(net, model, mnist_path)

  7. 使用以下命令运行脚本:

    python lenet.py --device_target=CPU

  8. 其中,

    lenet.py:为你根据教程编写的脚本文件。

    --device_target=CPU:指定运行硬件平台,参数为CPU、GPU或者Ascend,根据你的实际运行硬件平台来指定。

你可能感兴趣的:(深度学习,tensorflow,python)