模型在训练过程中或者在训练之后,模型的执行过程能被保存,也就是,模型能从暂停中恢复以免训练的时间过长。因此,被保存的模型可以被共享,其他人可以重新构建相同的模型。被保存的模型以如下的两种方式进行共享:
|
安装类库
如上所示,安装保存模型的类库、工具集,并导入其支持的函数,其中,以HDF5的文件格式保存模型。
加载数据样本
如上所示,加载mnist训练数据集,该数据集是流行服饰的图片数据集,其中,train_images是训练的图片数据集,train_labels是训练的标签数据集,分别使用mnist数据集的前1000条记录。
定义模型
如上所示,定义创建序列模型的函数、创建模型,其中,该模型仅定义一个全连接层、每层512个单元、使用删除正则化机制、使用Adam优化器优化模型。
保存检查点
检查点是用于保证训练过程的可用性,当训练过程被中断,设置检查点可以保存中断之前的训练数据,例如,保存了训练过程的权重或者参数,用户可以从检查点中恢复训练。Keras技术框架提供tf.keras.callbacks.ModelCheckpoint工具类用于保存训练过程的检查点,用户使用检查点可以持续地保存训练过程或者训练结束的训练数据。
检查点回调(callback)
如上所示,以回调的方式创建了一个模型的检查点,其中,checkpoint_path是保存检查点对应的文件路径,cp_callback是创建一个检查点函数,callbacks=[cp_callback]是为模型设置了训练过程中回调检查点函数、持续地保存训练过程中的检查点,该函数是在每次训练迭代结束的时候被回调,训练多次迭代则多次回调。
如上所示,显示训练结束之后,检查点保存路径下的文件信息。
如上所示,先创建一个模型model,在训练之前对模型执行测试评估,显示其准确度只有7%,使用模型的load_weights函数加载之前训练的检查点,创建一个权重都相同的模型,然后执行测试评估,显示其准确度达到86.60%(重用已训练完成的模型,包括权重)。
检查点选项设置
如上所示,用户可以根据实际情况设置检查点的参数选项,例如,checkpoint_path可以设置每次训练迭代保存的检查点的名称,save_freq设置检查点保存的频率,每多少次训练迭代保存一次检查点,日志显示每5次训练迭代记录一次检查点。
如上所示,显示模型检查点保存的文件列表,以及显示最后一个检查点。
如上所示,使用load_weights函数加载模型的最后一次训练的检查点,其包括该次检查点的权重,然后,执行模型的测试评估,其准确度达到87.30%。
检查点文件属性
检查点文件是二进制格式的字节文件,其中保存了训练过程中学习到的权重信息,其描述如下所示:
|
在单节点的机器中训练模型,则生成的其中一个分片的检查点文件的后缀是.data-00000-of-00001。
手动保存权重
如上所示,使用模型的save_weights保存模型的权重,load_weights函数是加载已保存的权重,然后,使用测试数据集对模型执行测试评估,其准确度达到87.30%。
保存完整模型
使用keras技术框架的tf.keras.Model.save函数可以保存完整的模型,该保存方式包括模型的架构、模型权重、模型训练的设置,该保存方式可以让用户直接重用已保存的模型,而不用修改模型的代码,也就是,由于优化器的状态能被恢复,用户可以直接恢复模型到上次训练暂停的时候。该保存方式的包括两种格式,SavedModel以及HDF5,其中,SavedModel是默认的保存格式。
SavedModel格式
如上所示,使用模型的save函数保存完整的已训练完成的模型。
如上所示,显示保存模型的文件列表,然后,使用load_model函数加载已保存的模型,输出模型的汇总信息。
如上所示,重新加载了已训练完成、完整的、相同的模型,然后,使用模型进行测试评估、预测。
HDF5格式
如上所示,创建一个模型、对模型执行训练、以HDF5的格式保存完整的模型。
如上所示,重新加载HDF5格式的已经保存的模型,输出模型汇总信息。
如上所示,使用测试数据集对重新加载的模型执行测试评估,其准确度达到85.90%。
由以上的分析可知,Keras技术框架保存模型的信息如下所示:
|
(未完待续)