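Define the callback tf.keras.callbacks.ModelCheckpoint() and pass it to model.fit(). During training, Keras calls the callback automatically (by default at the end of every epoch) and writes the model to the path given by filepath, here a .ckpt file. With save_weights_only=True only the parameter weights are saved. With the default save_weights_only=False the whole model is saved (architecture, weights and optimizer state).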
import os
import tensorflow as tf
from tensorflow import keras
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
# Define a simple sequential model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        # softmax: sparse_categorical_crossentropy expects probabilities, not logits, by default
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss=keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])
    return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
checkpoint_path = "tf_ckpt_logs/model.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=2,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]) # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings later in this example)
# are in place to discourage outdated usage, and can be ignored.
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_1 (Dense)              (None, 512)               401920
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
Train on 1000 samples, validate on 1000 samples
Epoch 1/2
32/1000 [..............................] - ETA: 1:03 - loss: 10.6530 - acc: 0.0938
384/1000 [==========>...................] - ETA: 3s - loss: 5.0778 - acc: 0.1823
736/1000 [=====================>........] - ETA: 0s - loss: 3.7411 - acc: 0.1793
1000/1000 [==============================] - 2s 2ms/step - loss: 3.3614 - acc: 0.1810 - val_loss: 2.3026 - val_acc: 0.1960
Epoch 00001: saving model to tf_ckpt_logs/model.ckpt
Epoch 2/2
32/1000 [..............................] - ETA: 0s - loss: 2.3026 - acc: 0.0938
384/1000 [==========>...................] - ETA: 0s - loss: 2.3026 - acc: 0.1380
736/1000 [=====================>........] - ETA: 0s - loss: 2.2901 - acc: 0.1603
1000/1000 [==============================] - 0s 194us/step - loss: 2.2888 - acc: 0.1590 - val_loss: 2.3026 - val_acc: 0.1830
Epoch 00002: saving model to tf_ckpt_logs/model.ckpt
Untrained model, accuracy: 12.30%
Restored model, accuracy: 18.30%
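As the log shows, the callback saved the model to tf_ckpt_logs/model.ckpt at the end of each epoch. The freshly created model scores 12.30% on the test set, while the model with the restored weights scores 18.30%, which matches the val_acc of the last epoch (0.1830).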
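In the log above both epochs are saved to the same tf_ckpt_logs/model.ckpt, so each save overwrites the previous one. ModelCheckpoint also accepts {epoch} formatting in filepath, which keeps one file per epoch. The sketch below assumes TF 2.x; it uses the .weights.h5 suffix because newer Keras versions (Keras 3) require it when save_weights_only=True. The random data and the names build_model and ckpt_dir are illustrative assumptions, not part of the original example.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Illustrative random data standing in for the MNIST subset (assumption)
rng = np.random.default_rng(0)
x = rng.random((64, 784), dtype=np.float32)
y = rng.integers(0, 10, size=(64,))


def build_model():
    """Hypothetical helper: a small classifier that outputs logits."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


ckpt_dir = tempfile.mkdtemp()
# {epoch:02d} is filled in by the callback, so every epoch gets its own file
ckpt_path = os.path.join(ckpt_dir, "cp-{epoch:02d}.weights.h5")
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=ckpt_path, save_weights_only=True, verbose=0)

model = build_model()
model.fit(x, y, epochs=3, callbacks=[cp_callback], verbose=0)

saved_files = sorted(os.listdir(ckpt_dir))
print(saved_files)  # one weights file per epoch

# Load the weights written after the last epoch into a fresh model
restored = build_model()
restored.load_weights(os.path.join(ckpt_dir, saved_files[-1]))
```

Only weights are stored here, so the model has to be rebuilt with the same architecture (build_model) before load_weights can be called.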
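With the TensorFlow 1.x low-level API (graph mode with tf.Session), call saver.save() periodically inside the training loop to write a checkpoint of the current model's variables into a folder (in the example below, all trainable variables).
For example, tf.train.Saver().save(sess, 'ckpts/model.ckpt') writes four files into the ckpts/ directory: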
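checkpoint: a plain-text file (readable with vim or any text editor) that records which checkpoints have been saved and which one is the latest.
The following three files together make up one checkpoint:
model.ckpt.data-00000-of-00001: the data file that holds the variable values of the checkpoint.
model.ckpt.index: the index file of the checkpoint, mapping each variable name to its data; binary, not human-readable.
model.ckpt.meta: the meta graph of the checkpoint (the graph structure, stored as a MetaGraphDef protocol buffer); binary, not human-readable.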
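To see these files for yourself, the sketch below builds a tiny graph with the tf.compat.v1 API (so it also runs under TF 2.x, assuming the compatibility module is available), saves it once with a Saver and lists the output directory. tf.train.list_variables then reads the variable names and shapes recorded in the .index/.data pair. The variable names h_est and v_est and the temporary directory are illustrative assumptions.

```python
import os
import tempfile

import tensorflow as tf

v1 = tf.compat.v1
ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model.ckpt")

# Build a graph-mode (TF1 style) model inside an explicit Graph
with tf.Graph().as_default():
    h_est = v1.get_variable("h_est", initializer=tf.constant([1.0]))
    v_est = v1.get_variable("v_est", initializer=tf.constant([-2.0]))
    saver = v1.train.Saver()
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        # Writes checkpoint, model.ckpt.index, model.ckpt.data-*, model.ckpt.meta
        saver.save(sess, prefix)

files = sorted(os.listdir(ckpt_dir))
print(files)

# The 'checkpoint' state file is plain text
with open(os.path.join(ckpt_dir, "checkpoint")) as f:
    print(f.read())

# Variable names and shapes stored in the checkpoint
print(tf.train.list_variables(prefix))
```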
Example:
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
h = 1
v = -2
# prepare data
x_train = np.linspace(-2, 4, 201)  # x samples
noise = np.random.randn(*x_train.shape) * 0.4  # noise
y_train = (x_train - h) ** 2 + v + noise  # y samples
n = x_train.shape[0]
x_train = np.reshape(x_train, (n, 1))  # reshape to (n, 1)
y_train = np.reshape(y_train, (n, 1))
# plot the generated data
'''
plt.rcParams['figure.figsize'] = (10, 6)
plt.scatter(x_train, y_train)
plt.xlabel('x_train')
plt.ylabel('y_train')
plt.show()
'''
# create variables
X = tf.placeholder(tf.float32, [1])  # two placeholders, for x and y
Y = tf.placeholder(tf.float32, [1])
h_est = tf.Variable(tf.random_uniform([1], -1, 1))  # parameters to train; define them before the Saver
v_est = tf.Variable(tf.random_uniform([1], -1, 1))
saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)  # saves the parameters, keeps the 3 newest checkpoints
value = (X - h_est) ** 2 + v_est  # the fitted curve
loss = tf.reduce_mean(tf.square(value - Y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)
init = tf.global_variables_initializer()
os.makedirs('model', exist_ok=True)  # Saver.save fails if the target directory is missing
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(10):  # 10 epochs
        for (x, y) in zip(x_train, y_train):
            sess.run(optimizer, feed_dict={X: x, Y: y})
        # save a checkpoint at the end of every epoch (model.ckpt-<epoch>)
        saver.save(sess, 'model/model.ckpt', global_step=epoch)
    # save the final model
    saver.save(sess, 'model/model.ckpt')  # checkpoint for the last epoch
    h_ = sess.run(h_est)
    v_ = sess.run(v_est)
    print(h_, v_)
[1.0023267] [-2.0263677]
The checkpoints are written to the model/ directory.
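Because the Saver above is created with max_to_keep=3 and saver.save is called with global_step, each save writes files with a -<step> suffix, and the Saver deletes the oldest checkpoint once more than three exist. The sketch below isolates this rotation using tf.compat.v1 (assuming TF 2.x with the v1 compatibility module); the single variable w, its assign_add update standing in for a training step, and the temporary directory are illustrative assumptions.

```python
import os
import tempfile

import tensorflow as tf

v1 = tf.compat.v1
ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model.ckpt")

with tf.Graph().as_default():
    w = v1.get_variable("w", initializer=tf.constant([0.0]))
    step_update = w.assign_add([1.0])  # stand-in for one training step
    saver = v1.train.Saver(max_to_keep=3)
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        for epoch in range(5):
            sess.run(step_update)
            saver.save(sess, prefix, global_step=epoch)  # model.ckpt-<epoch>
        kept = saver.last_checkpoints  # prefixes still on disk, oldest first

print([os.path.basename(p) for p in kept])
print(v1.train.latest_checkpoint(ckpt_dir))
```

After five saves only the last three prefixes remain; the files for steps 0 and 1 have been deleted from disk.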
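To load a saved model, use Saver.restore(). Depending on the var_list the Saver was constructed with, it restores either selected variables or all of them:
saver.restore(sess, model_path)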
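A minimal round trip, assuming TF 2.x with tf.compat.v1: two variables are saved, then restored into a fresh graph once with a Saver covering all variables and once with var_list=[h_est] so that only h_est is loaded. The helper build_graph, the saved values 1.5 and -2.5, and the temporary directory are hypothetical choices for illustration; tf.compat.v1.train.latest_checkpoint supplies model_path.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

v1 = tf.compat.v1
ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model.ckpt")

# 1) Save two variables (values chosen for illustration)
with tf.Graph().as_default():
    v1.get_variable("h_est", initializer=tf.constant([1.5]))
    v1.get_variable("v_est", initializer=tf.constant([-2.5]))
    saver = v1.train.Saver()
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        saver.save(sess, prefix, global_step=0)

model_path = v1.train.latest_checkpoint(ckpt_dir)  # prefix of the newest checkpoint


def build_graph():
    """Hypothetical helper: re-create the variables under the same names."""
    h = v1.get_variable("h_est", shape=[1], initializer=v1.zeros_initializer())
    v = v1.get_variable("v_est", shape=[1], initializer=v1.zeros_initializer())
    return h, v


# 2) Restore all variables; restore() sets them, so no initializer is run
with tf.Graph().as_default():
    h_est, v_est = build_graph()
    saver = v1.train.Saver()
    with v1.Session() as sess:
        saver.restore(sess, model_path)
        h_all, v_all = sess.run([h_est, v_est])

# 3) Restore only h_est; v_est keeps its own initial value (zeros)
with tf.Graph().as_default():
    h_est, v_est = build_graph()
    partial_saver = v1.train.Saver(var_list=[h_est])
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        partial_saver.restore(sess, model_path)
        h_part, v_part = sess.run([h_est, v_est])

print(h_all, v_all)
print(h_part, v_part)
```

Restoring matches variables by name, which is why the graph is rebuilt with the same variable names before calling restore.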
References: https://www.jianshu.com/p/0bcaab1e7cda
https://blog.csdn.net/index20001/article/details/74322198