Saving TensorFlow models and checkpoints during training

1. Using tf.keras.callbacks.ModelCheckpoint()

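Define a tf.keras.callbacks.ModelCheckpoint() callback and pass it to model.fit() through the callbacks argument. Keras then calls it automatically during training and, by default, saves the model at the end of every epoch to the path given by filepath (here tf_ckpt_logs/model.ckpt). With the default save_weights_only=False the saved model contains the architecture, the weights and the optimizer state; with save_weights_only=True only the weights are saved.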

Example:

import os
import tensorflow as tf
from tensorflow import keras

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

# Define a simple sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])
  # The last Dense layer outputs raw logits (no softmax), hence from_logits=True
  model.compile(optimizer='adam',
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()

checkpoint_path = "tf_ckpt_logs/model.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model at the end of every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images,
          train_labels,
          epochs=2,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# Saving may print warnings about the optimizer state;
# they discourage outdated usage and can be ignored.

# Create a fresh, untrained model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))

# Load the saved weights into the fresh model
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
Train on 1000 samples, validate on 1000 samples
Epoch 1/2
2020-04-08 13:40:08.180275: I C:\users\nwani\_bazel_nwani\mmtm6wb6\execroot\org_tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
2020-04-08 13:40:08.925325: I C:\users\nwani\_bazel_nwani\mmtm6wb6\execroot\org_tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1356] Found device 0 with properties: 
name: GeForce GTX 950M major: 5 minor: 0 memoryClockRate(GHz): 1.124
pciBusID: 0000:01:00.0
totalMemory: 2.00GiB freeMemory: 1.64GiB
2020-04-08 13:40:08.925690: I C:\users\nwani\_bazel_nwani\mmtm6wb6\execroot\org_tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1435] Adding visible gpu devices: 0
2020-04-08 13:40:09.688851: I C:\users\nwani\_bazel_nwani\mmtm6wb6\execroot\org_tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:923] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-04-08 13:40:09.689071: I C:\users\nwani\_bazel_nwani\mmtm6wb6\execroot\org_tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:929]      0 
2020-04-08 13:40:09.689210: I C:\users\nwani\_bazel_nwani\mmtm6wb6\execroot\org_tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:942] 0:   N 
2020-04-08 13:40:09.689480: I C:\users\nwani\_bazel_nwani\mmtm6wb6\execroot\org_tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1053] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 1403 MB memory) -> physical GPU (device: 0, name: GeForce GTX 950M, pci bus id: 0000:01:00.0, compute capability: 5.0)

  32/1000 [..............................] - ETA: 1:03 - loss: 10.6530 - acc: 0.0938
 384/1000 [==========>...................] - ETA: 3s - loss: 5.0778 - acc: 0.1823   
 736/1000 [=====================>........] - ETA: 0s - loss: 3.7411 - acc: 0.1793
1000/1000 [==============================] - 2s 2ms/step - loss: 3.3614 - acc: 0.1810 - val_loss: 2.3026 - val_acc: 0.1960

Epoch 00001: saving model to tf_ckpt_logs/model.ckpt
Epoch 2/2

  32/1000 [..............................] - ETA: 0s - loss: 2.3026 - acc: 0.0938
 384/1000 [==========>...................] - ETA: 0s - loss: 2.3026 - acc: 0.1380
 736/1000 [=====================>........] - ETA: 0s - loss: 2.2901 - acc: 0.1603
1000/1000 [==============================] - 0s 194us/step - loss: 2.2888 - acc: 0.1590 - val_loss: 2.3026 - val_acc: 0.1830

Epoch 00002: saving model to tf_ckpt_logs/model.ckpt
Untrained model, accuracy: 12.30%
Restored model, accuracy: 18.30%

Process finished with exit code 0

As the log shows, the model was saved to model.ckpt after each epoch.
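A variation on the callback above, as a hedged sketch assuming TensorFlow 2.x (Keras 2 or Keras 3): save only the weights, one file per epoch, and reload the newest file into a fresh model. The random data, the small create_model() and the cp-{epoch:02d}.weights.h5 naming are assumptions for illustration; the .weights.h5 suffix is used because Keras 3 requires it when save_weights_only=True.

```python
import glob
import os
import tempfile

import numpy as np
from tensorflow import keras


def create_model():
    # Deliberately small model; this architecture is an assumption, not the one above
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(10),  # raw logits
    ])
    model.compile(optimizer="adam",
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=["accuracy"])
    return model


# Hypothetical stand-in data; swap in the MNIST arrays from the example above
x = np.random.rand(64, 784).astype("float32")
y = np.random.randint(0, 10, size=(64,))

ckpt_dir = tempfile.mkdtemp()
# The callback fills in {epoch:02d} (1-based), so each epoch gets its own file
ckpt_pattern = os.path.join(ckpt_dir, "cp-{epoch:02d}.weights.h5")
cp_callback = keras.callbacks.ModelCheckpoint(filepath=ckpt_pattern,
                                              save_weights_only=True)

model = create_model()
model.fit(x, y, epochs=2, batch_size=32, callbacks=[cp_callback], verbose=0)

# Zero-padded epoch numbers sort correctly as strings, so the last entry is the newest
saved = sorted(glob.glob(os.path.join(ckpt_dir, "cp-*.weights.h5")))
latest = saved[-1]

# A fresh model with the same architecture picks up the saved weights
restored = create_model()
restored.load_weights(latest)
print("restored from", latest)
```

Putting {epoch:02d} in filepath keeps every epoch's weights, whereas the example above writes each epoch to the same model.ckpt path and overwrites the previous one.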

2. Using tf.train.Saver() (TensorFlow 1.x API)

2.1 Saving the model

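Inside the training loop, call saver.save() periodically to write a checkpoint of the current variable values into a folder. By default a Saver covers all global variables; the example below passes tf.trainable_variables(), so only the trainable variables are saved.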

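(1) What Saver().save() writes:

tf.train.Saver().save(sess, 'ckpts/model.ckpt') writes four files under the ckpts/ directory:

checkpoint: a plain-text file (viewable with vim or any editor) that records which checkpoints have been saved and which one is the latest.

The following three files together make up one checkpoint:

model.ckpt.data-00000-of-00001: the checkpoint's data file, holding the variable values

model.ckpt.index: the checkpoint's index file, mapping variable names to their location in the data file; binary, not human-readable

model.ckpt.meta: the checkpoint's meta data (the graph structure); binary, not human-readable

When save() is called with global_step=N, the prefix becomes model.ckpt-N, e.g. model.ckpt-9.index.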
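The checkpoint bookkeeping file is plain text, one key: "value" entry per line. In real code tf.train.get_checkpoint_state() or tf.train.latest_checkpoint() read it for you; the minimal sketch below parses it by hand only to show what it records. The sample contents are assumed for illustration, and parse_checkpoint_state is a hypothetical helper.

```python
import re

# Illustrative contents of a `checkpoint` file (assumed, not copied from a real run);
# the exact paths depend on the prefix passed to saver.save().
SAMPLE = '''model_checkpoint_path: "model.ckpt-9"
all_model_checkpoint_paths: "model.ckpt-7"
all_model_checkpoint_paths: "model.ckpt-8"
all_model_checkpoint_paths: "model.ckpt-9"
'''


def parse_checkpoint_state(text):
    """Hypothetical helper: return (latest prefix, list of all kept prefixes)."""
    latest, kept = None, []
    for line in text.splitlines():
        m = re.match(r'\s*(\w+):\s*"(.*)"\s*$', line)
        if not m:
            continue  # skip blank lines and unquoted fields such as timestamps
        key, value = m.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            kept.append(value)
    return latest, kept


latest, kept = parse_checkpoint_state(SAMPLE)
print(latest)  # model.ckpt-9
print(kept)    # ['model.ckpt-7', 'model.ckpt-8', 'model.ckpt-9']
```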

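(2) Usage notes:

  1. Saver operations (save and restore) must run after the session has been created, and the Saver itself should be created after the variables it is meant to save have been defined.
  2. Put model.ckpt inside a folder, e.g. 'model/model.ckpt' (at least one directory level); otherwise saving can fail.
  3. When restoring, pass the same path used when saving, e.g. 'model/model.ckpt'. This prefix differs from all three file names: it has no .data, .index or .meta suffix.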
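Note 3 is a common stumbling block when picking a file from the folder by hand. A minimal sketch, in plain Python, of turning any of the three checkpoint file names back into the prefix that saver.restore() expects; checkpoint_prefix is a hypothetical helper, and the suffix pattern assumes the naming listed above.

```python
import re

# Suffixes of the files that make up one checkpoint: .index, .meta, .data-XXXXX-of-YYYYY
_CKPT_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(filename):
    """Hypothetical helper: map a checkpoint file name to its restore prefix,
    e.g. 'model/model.ckpt-9.index' -> 'model/model.ckpt-9'."""
    prefix, n = _CKPT_SUFFIX.subn("", filename)
    if n == 0:
        raise ValueError("not a checkpoint file: %r" % filename)
    return prefix


print(checkpoint_prefix("model/model.ckpt-9.data-00000-of-00001"))  # model/model.ckpt-9
```

The result would then be passed as saver.restore(sess, checkpoint_prefix(path)).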

Example:

import tensorflow as tf   # TF 1.x API: tf.placeholder, tf.Session, tf.train.Saver
import numpy as np
import matplotlib.pyplot as plt

h = 1
v = -2

#prepare data
x_train = np.linspace(-2, 4, 201)                        # x samples
noise = np.random.randn(*x_train.shape) * 0.4            # noise
y_train = (x_train - h) ** 2 + v + noise                 # y samples: y = (x - h)^2 + v + noise

n = x_train.shape[0]

x_train = np.reshape(x_train, (n, 1))                    # reshape to (n, 1)
y_train = np.reshape(y_train, (n, 1))

# Plot the generated data (disabled)
'''
plt.rcParams['figure.figsize'] = (10, 6)
plt.scatter(x_train, y_train)
plt.xlabel('x_train')
plt.ylabel('y_train')
plt.show()
'''
#create variable
X = tf.placeholder(tf.float32, [1])                      # two placeholders, for x and y
Y = tf.placeholder(tf.float32, [1])

h_est = tf.Variable(tf.random_uniform([1], -1, 1))       # parameters to train; define them before the Saver
v_est = tf.Variable(tf.random_uniform([1], -1, 1))

saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)   # Saver for the model parameters, keeps the 3 most recent checkpoints

value = (X - h_est) ** 2 + v_est                         # fitted curve

loss = tf.reduce_mean(tf.square(value - Y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)


init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    for epoch in range(10):                             # 10 epochs
        for (x, y) in zip(x_train, y_train):

            sess.run(optimizer, feed_dict={X: x, Y: y})
        # save a checkpoint at the end of each epoch (model.ckpt-<epoch>)
        saver.save(sess, 'model/model.ckpt', global_step=epoch)

    # save the final model
    saver.save(sess, 'model/model.ckpt')                    # checkpoint of the state after the last epoch
    h_ = sess.run(h_est)
    v_ = sess.run(v_est)

    print(h_, v_)

Output:

[1.0023267] [-2.0263677]

The checkpoints were saved to the model/ folder:

[Figure 1: checkpoint files saved under model/]
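With max_to_keep=3 the Saver keeps only the three most recent checkpoints and deletes older ones as new ones are written. The sketch below simulates that retention rule in plain Python for the save calls made in the example (model.ckpt-0 to model.ckpt-9, then model.ckpt). simulate_max_to_keep is a hypothetical helper; it assumes max_to_keep >= 1 and ignores keep_checkpoint_every_n_hours.

```python
def simulate_max_to_keep(save_prefixes, max_to_keep):
    """Hypothetical helper mimicking tf.train.Saver's retention rule:
    keep the `max_to_keep` most recent checkpoints; re-saving an existing
    prefix moves it to the newest position."""
    if max_to_keep < 1:
        raise ValueError("this sketch assumes max_to_keep >= 1")
    kept, deleted = [], []
    for prefix in save_prefixes:
        if prefix in kept:
            kept.remove(prefix)          # re-saved: becomes the newest again
        kept.append(prefix)
        if len(kept) > max_to_keep:
            deleted.append(kept.pop(0))  # the oldest checkpoint's files are removed
    return kept, deleted


# Save calls made by the example: one per epoch with global_step, then a final save
saves = ["model/model.ckpt-%d" % epoch for epoch in range(10)] + ["model/model.ckpt"]
kept, deleted = simulate_max_to_keep(saves, max_to_keep=3)
print(kept)  # ['model/model.ckpt-8', 'model/model.ckpt-9', 'model/model.ckpt']
```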

2.2 Restoring the model

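Use Saver.restore(). It can load either selected parameters or all of them: the Saver restores exactly the variables in the var_list it was constructed with (all global variables by default).

    saver.restore(sess, model_path)   # model_path: the prefix used when saving, e.g. 'model/model.ckpt'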
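Expanding the one-line restore call into a self-contained sketch, assuming TensorFlow 1.14+ or 2.x run through tf.compat.v1 in graph mode: it saves two variables to a temporary directory, restores both in a fresh graph, then restores only one of them by passing var_list. The variable names h_est/v_est and the temporary directory are assumptions for illustration; restore matches variables by name, so the graph must be rebuilt with the same names.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as tf.Session / tf.train.Saver expect


def build_variables():
    # Rebuild the variables with the SAME names each time; restore matches by name.
    # The names "h_est"/"v_est" are assumptions (the example above leaves them unnamed).
    h_est = tf.Variable(tf.random.uniform([1], -1, 1), name="h_est")
    v_est = tf.Variable(tf.random.uniform([1], -1, 1), name="v_est")
    return h_est, v_est


ckpt_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(ckpt_dir, "model.ckpt")

# 1) Save: give the variables known values, then write a checkpoint
with tf1.Graph().as_default():
    h_est, v_est = build_variables()
    saver = tf1.train.Saver()  # default var_list: all global variables
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run([h_est.assign([1.0]), v_est.assign([-2.0])])
        saver.save(sess, ckpt_prefix)

# 2) Restore all variables in a fresh graph; restored variables need no initializer
with tf1.Graph().as_default():
    h_est, v_est = build_variables()
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
        restored_all = sess.run([h_est, v_est])

# 3) Restore only h_est by listing it in var_list; v_est keeps its random init
with tf1.Graph().as_default():
    h_est, v_est = build_variables()
    saver = tf1.train.Saver(var_list=[h_est])
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.restore(sess, ckpt_prefix)
        restored_h = sess.run(h_est)

print(restored_all, restored_h)
```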

References: https://www.jianshu.com/p/0bcaab1e7cda

https://blog.csdn.net/index20001/article/details/74322198
