保存与加载模型(minist数据集)

保存与加载模型(minist数据集)

导入库&数据预处理

import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

建个简单模型

# build a model
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

保存与加载模型(minist数据集)_第1张图片

在训练期间保存模型

import os
checkpoint_path = "快速入门01/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# 创建一个保存模型权重的回调
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,                                                 verbose=2)

# 使用新的回调训练模型
model.fit(x_train,
         y_train,
         epochs=10,
         validation_data=(x_test,y_test),
         callbacks=[cp_callback])

保存与加载模型(minist数据集)_第2张图片

预测一下

保存与加载模型(minist数据集)_第3张图片

你可能感兴趣的:(tensorflow与图像处理)