这节课,我们介绍三种保存模型的方法,另外介绍两个很有用的工具:一个是 TensorFlow 游乐场(TensorFlow Playground),一个是 TensorBoard。这里只是浅浅带过,以后会深入讨论。
昨天没更新,属实是累了,下一篇卷积神经网络,冲冲冲
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
# step1: load the MNIST training and test splits
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale raw pixel intensities by 1/255 (true division, so the result is float).
x_train = x_train / 255.0
x_test = x_test / 255.0
# step2: build the model
def create_model():
    """Build (but do not compile) a small fully connected digit classifier.

    Layer stack: Flatten(input 28x28) -> Dense(512, relu) -> Dropout(0.2)
    -> Dense(10, softmax).

    Returns:
        A new, uncompiled ``tf.keras.models.Sequential`` model.
    """
    layer_stack = [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax'),
    ]
    return tf.keras.models.Sequential(layer_stack)
model = create_model()
# step3: compile - choose the optimizer, the loss function and the reported metrics.
# sparse_categorical_crossentropy takes integer class labels rather than one-hot vectors.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# step4: train for a single epoch over the whole training set
model.fit(x=x_train, y=y_train, epochs=1)
# step5: evaluate on the held-out test set and report accuracy as a percentage
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(acc * 100))
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2195 - accuracy: 0.9350
313/313 [==============================] - 1s 2ms/step - loss: 0.1053 - accuracy: 0.9678
train model, accuracy:96.78%
# step6: persist only the layer parameters (weights and biases) to an HDF5 file.
# The architecture is not stored, so restoring needs the network rebuilt first.
model.save_weights('./min.h5')
# step7: drop the trained model object
del model
# step8: rebuild a fresh network with the same architecture and compile it identically
model = create_model()
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# step9: copy the saved parameters into the rebuilt network
model.load_weights('./min.h5')
# step10: the restored network should reproduce the trained model's test accuracy
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(acc * 100))
313/313 [==============================] - 1s 2ms/step - loss: 0.1053 - accuracy: 0.9678
Restored model, accuracy:96.78%
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
# step1: load the MNIST training and test splits
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale raw pixel intensities by 1/255 (true division, so the result is float).
x_train = x_train / 255.0
x_test = x_test / 255.0
# step2: build the model
def create_model():
    """Build (but do not compile) a small fully connected digit classifier.

    Layer stack: Flatten(input 28x28) -> Dense(512, relu) -> Dropout(0.2)
    -> Dense(10, softmax).

    Returns:
        A new, uncompiled ``tf.keras.models.Sequential`` model.
    """
    layer_stack = [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax'),
    ]
    return tf.keras.models.Sequential(layer_stack)
model = create_model()
# step3: compile - choose the optimizer, the loss function and the reported metrics.
# sparse_categorical_crossentropy takes integer class labels rather than one-hot vectors.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# step4: train for a single epoch over the whole training set
model.fit(x=x_train, y=y_train, epochs=1)
# step5: evaluate on the held-out test set and report accuracy as a percentage
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(acc * 100))
# step6: save the WHOLE model (architecture + weights + compile/training
# configuration) to a single HDF5 file - not just the weights and biases.
model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
# step7: delete the in-memory model
del model  # deletes the existing model
# step8: restore the model from disk. Unlike save_weights/load_weights, no
# create_model()/compile() call is needed first: load_model rebuilds the
# architecture and returns a compiled model identical to the saved one.
restored_model = tf.keras.models.load_model('my_model.h5')
# step9: evaluate the restored model; it should match the pre-save accuracy
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2190 - accuracy: 0.9355
313/313 [==============================] - 1s 2ms/step - loss: 0.1026 - accuracy: 0.9679
train model, accuracy:96.79%
313/313 [==============================] - 1s 2ms/step - loss: 0.1026 - accuracy: 0.9679
Restored model, accuracy:96.79%
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
# step1: load the MNIST training and test splits
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale raw pixel intensities by 1/255 (true division, so the result is float).
x_train = x_train / 255.0
x_test = x_test / 255.0
# step2: build the model
def create_model():
    """Build (but do not compile) a small fully connected digit classifier.

    Layer stack: Flatten(input 28x28) -> Dense(512, relu) -> Dropout(0.2)
    -> Dense(10, softmax).

    Returns:
        A new, uncompiled ``tf.keras.models.Sequential`` model.
    """
    layer_stack = [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax'),
    ]
    return tf.keras.models.Sequential(layer_stack)
model = create_model()
# step3: compile - choose the optimizer, the loss function and the reported metrics.
# sparse_categorical_crossentropy takes integer class labels rather than one-hot vectors.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# ---------------------------- callbacks ----------------------------
logdir = './logs'  # directory the TensorBoard callback writes its logs to
# Checkpoint file-name template: Keras fills in the epoch number and the
# validation loss, e.g. ./checkpoint/min.01-0.13.ckpt
checkpoint_path = './checkpoint/min.{epoch:02d}-{val_loss:.2f}.ckpt'


def scheduler(epoch, lr):
    """Learning-rate schedule used by LearningRateScheduler.

    Keeps the current learning rate while ``epoch < 10``; from epoch 10 on,
    multiplies it by exp(-0.1) at every epoch.

    Args:
        epoch: zero-based index of the epoch about to start.
        lr: the optimizer's current learning rate.

    Returns:
        The learning rate to use for this epoch.
    """
    if epoch < 10:
        return lr
    return lr * tf.math.exp(-0.1)


callbacks = [
    # TensorBoard logging; weight histograms are computed every 2 epochs,
    # so a run shorter than 2 epochs records no histograms.
    tf.keras.callbacks.TensorBoard(log_dir=logdir,  # log directory
                                   histogram_freq=2),  # histogram frequency (epochs)
    # Checkpoint the model at the end of every epoch.
    # save_freq='epoch' replaces the deprecated `period=1` argument, which
    # only triggered a deprecation warning; the saving schedule is unchanged.
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,  # where to save
                                       save_weights_only=True,  # weights and biases only, no architecture
                                       verbose=1,  # print a message whenever a checkpoint is saved
                                       save_freq='epoch'),
    # Early stopping: halt training once val_loss has not improved for 3
    # consecutive epochs, to limit overfitting. val_loss only exists when
    # fit() is given validation data (validation_split / validation_data).
    tf.keras.callbacks.EarlyStopping(monitor='val_loss',  # quantity to monitor
                                     patience=3),  # epochs without improvement allowed
    # Learning-rate schedule (see scheduler above)
    tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=0),
]
# step4: train for one epoch, holding out 20% of the training data as a
# validation split (this supplies the val_loss the callbacks rely on).
model.fit(
    x=x_train,
    y=y_train,
    epochs=1,
    validation_split=0.2,
    callbacks=callbacks,
)
# step5: evaluate on the held-out test set and report accuracy as a percentage
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(acc * 100))
# step6: save the WHOLE model (architecture + weights + compile/training
# configuration) to an HDF5 file.
model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
# step7: delete the in-memory model
del model
# step8: rebuild a fresh network with the same architecture and compile it identically
model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# step9: restore weights from the newest checkpoint written by ModelCheckpoint
import os

checkpoint_path = './checkpoint/min.{epoch:02d}-{val_loss:.2f}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)  # -> './checkpoint'
latest = tf.train.latest_checkpoint(checkpoint_dir)  # newest checkpoint, or None
if latest is None:
    # Fail with a clear message instead of passing None to load_weights.
    raise FileNotFoundError(
        "No checkpoint found in {!r}; run model.fit with the "
        "ModelCheckpoint callback first.".format(checkpoint_dir))
model.load_weights(latest)
# step10: evaluate the restored model
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))
WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.
1500/1500 [==============================] - 9s 6ms/step - loss: 0.2449 - accuracy: 0.9279 - val_loss: 0.1252 - val_accuracy: 0.9632
Epoch 00001: saving model to ./checkpoint/min.01-0.13.ckpt
313/313 [==============================] - 1s 3ms/step - loss: 0.1215 - accuracy: 0.9617
train model, accuracy:96.17%
313/313 [==============================] - 1s 2ms/step - loss: 0.1215 - accuracy: 0.9617
Restored model, accuracy:96.17%
# Notebook-only (IPython magics): load the TensorBoard extension and launch
# TensorBoard on the ./logs directory used as the TensorBoard callback's log_dir.
# If the extension is already loaded in this kernel, `%reload_ext tensorboard`
# reloads it instead of printing the "already loaded" notice.
%load_ext tensorboard
%tensorboard --logdir logs
The tensorboard extension is already loaded. To reload it, use:
%reload_ext tensorboard
Reusing TensorBoard on port 6006 (pid 776), started 0:00:45 ago. (Use '!kill 776' to kill it.)