TensorFlow之模型保存与加载

模型在训练过程中或者在训练之后,模型的执行过程能被保存,也就是,模型能从暂停中恢复以免训练的时间过长。因此,被保存的模型可以被共享,其他人可以重新构建相同的模型。被保存的模型以如下的两种方式进行共享:

  •  创建模型的代码

  • 被训练模型的权重或者参数

安装类库

TensorFlow之模型保存与加载_第1张图片

如上所示,安装保存模型的类库、工具集,并导入其支持的函数,其中,以HDF5的文件格式保存模型。

加载数据样本

TensorFlow之模型保存与加载_第2张图片

如上所示,加载mnist训练数据集,该数据集是流行服饰的图片数据集,其中,train_images是训练的图片数据集,train_labels是训练的标签数据集,分别使用mnist数据集的前1000条记录。

定义模型

TensorFlow之模型保存与加载_第3张图片

TensorFlow之模型保存与加载_第4张图片

如上所示,定义创建序列模型的函数、创建模型,其中,该模型仅定义一个全连接层、每层512个单元、使用删除正则化机制、使用Adam优化器优化模型。

保存检查点

检查点是用于保证训练过程的可用性,当训练过程被中断,设置检查点可以保存中断之前的训练数据,例如,保存了训练过程的权重或者参数,用户可以从检查点中恢复训练。Keras技术框架提供tf.keras.callbacks.ModelCheckpoint工具类用于保存训练过程的检查点,用户使用检查点可以持续地保存训练过程或者训练结束的训练数据。

检查点回调(callback)

TensorFlow之模型保存与加载_第5张图片

如上所示,以回调的方式创建了一个模型的检查点,其中,checkpoint_path是保存检查点对应的文件路径,cp_callback是创建一个检查点函数,callbacks=[cp_callback]是为模型设置了训练过程中回调检查点函数、持续地保存训练过程中的检查点,该函数是在每次训练迭代结束的时候被回调,训练多次迭代则多次回调。

TensorFlow之模型保存与加载_第6张图片

如上所示,显示训练结束之后,检查点保存路径下的文件信息。

TensorFlow之模型保存与加载_第7张图片

如上所示,先创建一个模型model,在训练之前对模型执行测试评估,显示其准确度只有7%,使用模型的load_weights函数加载之前训练的检查点,创建一个权重都相同的模型,然后执行测试评估,显示其准确度达到86.60%(重用已训练完成的模型,包括权重)。

检查点选项设置

TensorFlow之模型保存与加载_第8张图片

TensorFlow之模型保存与加载_第9张图片

如上所示,用户可以根据实际情况设置检查点的参数选项,例如,checkpoint_path可以设置每次训练迭代保存的检查点的名称,save_freq设置检查点保存的频率,每多少次训练迭代保存一次检查点,日志显示每5次训练迭代记录一次检查点。

TensorFlow之模型保存与加载_第10张图片

如上所示,显示模型检查点保存的文件列表,以及显示最后一个检查点。

TensorFlow之模型保存与加载_第11张图片

如上所示,使用load_weights函数加载模型的最后一次训练的检查点,其包括该次检查点的权重,然后,执行模型的测试评估,其准确度达到87.30%。

检查点文件属性

检查点文件是二进制格式的字节文件,其中保存了训练过程中学习到的权重信息,其描述如下所示:

  • 一个或者多个分片,每个分片都包含了模型训练所得的权重信息

  • 一个索引文件,对以上的分片文件进行索引

在单节点的机器中训练模型,则生成的其中一个分片的检查点文件的后缀是.data-00000-of-00001。

手动保存权重

TensorFlow之模型保存与加载_第12张图片

如上所示,使用模型的save_weights保存模型的权重,load_weights函数是加载已保存的权重,然后,使用测试数据集对模型执行测试评估,其准确度达到87.30%。

保存完整模型

使用keras技术框架的tf.keras.Model.save函数可以保存完整的模型,该保存方式包括模型的架构、模型权重、模型训练的设置,该保存方式可以让用户直接重用已保存的模型,而不用修改模型的代码,也就是,由于优化器的状态能被恢复,用户可以直接恢复模型到上次训练暂停的时候。该保存方式的包括两种格式,SavedModel以及HDF5,其中,SavedModel是默认的保存格式。

SavedModel格式

TensorFlow之模型保存与加载_第13张图片

如上所示,使用模型的save函数保存完整的已训练完成的模型。

TensorFlow之模型保存与加载_第14张图片

TensorFlow之模型保存与加载_第15张图片

如上所示,显示保存模型的文件列表,然后,使用load_model函数加载已保存的模型,输出模型的汇总信息。

TensorFlow之模型保存与加载_第16张图片

如上所示,重新加载了已训练完成、完整的、相同的模型,然后,使用模型进行测试评估、预测。

HDF5格式

TensorFlow之模型保存与加载_第17张图片

如上所示,创建一个模型、对模型执行训练、以HDF5的格式保存完整的模型。

TensorFlow之模型保存与加载_第18张图片

如上所示,重新加载HDF5格式的已经保存的模型,输出模型汇总信息。

TensorFlow之模型保存与加载_第19张图片

如上所示,使用测试数据集对重新加载的模型执行测试评估,其准确度达到85.90%。

由以上的分析可知,Keras技术框架保存模型的信息如下所示:

  •  模型训练所得的权重值

  • 模型的架构

  • 模型的训练设置,例如,传入到模型compile函数的参数

  • 模型的优化器以及优化器的状态

(未完待续)

你可能感兴趣的:(人工智能技术与架构,tensorflow,深度学习)