Multiple ways to save and reuse TensorFlow models

To borrow a definition from the database world: a checkpoint is an internal event that, when triggered, causes the database writer process to flush dirty data from the data buffer to the data files.

A checkpoint serves two main purposes:

guaranteeing the consistency of the database

shortening instance recovery time

Colloquially speaking, a checkpoint works like Word's auto-save.

A TensorFlow model consists of a meta graph (the network structure) and checkpoint files (the parameter values of that structure, nowadays split across three files).

So a save directory contains the following files:

model.data-00000-of-00001    stores the variable values

model.index                  records the correspondence between the .data and .meta files

model.meta                   the graph structure

checkpoint                   a plain-text file recording the names of the models saved at intermediate steps
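
As a quick way to check what a checkpoint actually contains, here is a minimal sketch (the path prefix is illustrative; substitute your own) that lists the variables stored in a .index/.data pair without rebuilding the graph:

import tensorflow as tf

# Illustrative checkpoint prefix -- substitute your own save path
CKPT_PREFIX = r"E:\tf_project\练习\model_save_dir\my-model-20000"

# NewCheckpointReader reads the .index/.data files directly,
# so no graph or session is required
reader = tf.train.NewCheckpointReader(CKPT_PREFIX)
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape, reader.get_tensor(name))

The running example used throughout this post is the simple linear model below.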

import tensorflow as tf
import os


# Model parameters, defined at module level and shared by all examples below
W = tf.Variable(tf.zeros([2, 1]), name="weights")
b = tf.Variable(0., name="bias")


def inference(X):
    # Linear model: y = X * W + b
    return tf.matmul(X, W) + b


def loss(X, Y):
    # Sum of squared differences between targets and predictions
    Y_predicted = inference(X)
    return tf.reduce_sum(tf.squared_difference(Y, Y_predicted))


def inputs():
    # Toy dataset: (weight, age) pairs and the corresponding blood fat content
    weight_age = [[84, 46], [73, 20], [65, 52], [70, 30], [76, 57], [69, 25], [63, 28], [72, 36], [79, 57], [75, 44],
                  [27, 24], [89, 31], [65, 52], [57, 23], [59, 60], [69, 48], [60, 34], [79, 51], [75, 50], [82, 34],
                  [59, 46], [67, 23], [85, 37], [55, 40], [63, 30]]
    blood_fat_content = [354, 190, 405, 263, 451, 302, 288, 385, 402, 365, 209, 290, 346, 254, 395, 434, 220, 374, 308,
                         220, 311, 181, 274, 303, 244]
    return tf.to_float(weight_age), tf.to_float(blood_fat_content)


def train(total_loss):
    # Plain gradient descent with a very small learning rate
    learning_rate = 0.0000001
    return tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)


def evaluate(sess, X, Y):
    # Spot-check predictions against rough expected values
    print(sess.run(inference([[80., 25.]])))  # expected around 303
    print(sess.run(inference([[65., 25.]])))  # expected around 256

【1】Model training:

with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    init = tf.global_variables_initializer()
    sess.run(init)
    training_steps = 10000
    saver = tf.train.Saver()  # by default keeps only the 5 most recent checkpoints
    for step in range(training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\练习\model_save_dir\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\练习\model_save_dir\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()
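
Note that tf.train.Saver keeps only the 5 most recent checkpoints by default (max_to_keep=5), so with a save every 1000 steps only the last five survive on disk; that is why only steps 16000-20000 appear in the directory listing later. A minimal sketch if you want to retain more:

# Keep the 10 most recent checkpoints, plus one every 2 hours of training time
saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=2)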

【2】Reloading the model

1. Load the most recent model, using ckpt.model_checkpoint_path

with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    initial_step = 0
    training_steps = 30000
    saver = tf.train.Saver()
    # os.path.dirname strips the "my-model" file prefix, leaving the save directory
    ckpt = tf.train.get_checkpoint_state(os.path.dirname(r"E:\tf_project\练习\model_save_dir\my-model"))

    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)  # restores all variable values
        print(ckpt.model_checkpoint_path)
        # Recover the step counter from the "-NNNN" suffix of the checkpoint name
        initial_step = int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])
    else:
        sess.run(tf.global_variables_initializer())  # no checkpoint found: start fresh

    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\练习\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\练习\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()



output:

E:\tf_project\练习\model_save_dir\my-model-9000-20000
loss 5214449.5
loss 5214338.0
loss 5214226.0
loss 5214114.0
.
.
.
loss 5106910.0
loss 5106805.5
loss 5106701.0
loss 5106597.0
[[319.9712]]
[[270.7156]]
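
For reference, the ckpt object returned by tf.train.get_checkpoint_state is a CheckpointState protocol buffer parsed from the plain-text checkpoint file in the save directory; a typical file looks like this (paths illustrative):

model_checkpoint_path: "my-model-20000"
all_model_checkpoint_paths: "my-model-16000"
all_model_checkpoint_paths: "my-model-17000"
all_model_checkpoint_paths: "my-model-18000"
all_model_checkpoint_paths: "my-model-19000"
all_model_checkpoint_paths: "my-model-20000"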

2. Load one or more of the most recently saved models

Use ckpt.all_model_checkpoint_paths.


with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    init = tf.global_variables_initializer()
    sess.run(init)
    training_steps = 30000
    saver = tf.train.Saver()

    # os.path.dirname strips the "my-model" file prefix, leaving the save directory
    ckpt = tf.train.get_checkpoint_state(os.path.dirname(r"E:\tf_project\练习\model_save_dir\my-model"))
    initial_step = 0
    if ckpt and ckpt.all_model_checkpoint_paths:
        print(ckpt.all_model_checkpoint_paths)
        path = ckpt.all_model_checkpoint_paths[1]  # pick any retained checkpoint; here the second-oldest
        saver.restore(sess, path)
        initial_step = int(path.rsplit('-', 1)[1])  # recover the step counter from the restored name

    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\练习\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\练习\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()

output:

['E:\\tf_project\\练习\\model_save_dir\\my-model-9000-16000',
 'E:\\tf_project\\练习\\model_save_dir\\my-model-9000-17000',
 'E:\\tf_project\\练习\\model_save_dir\\my-model-9000-18000', 
 'E:\\tf_project\\练习\\model_save_dir\\my-model-9000-19000',
 'E:\\tf_project\\练习\\model_save_dir\\my-model-9000-20000']
loss 5217809.0
loss 5217696.5
loss 5217585.0
loss 5217473.0
loss 5217361.0
.
.
.
loss 5110039.5
loss 5109935.5
loss 5109830.5
loss 5109727.0
[[319.98105]]
[[270.67313]]
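
Since all_model_checkpoint_paths is just a list, you can also iterate over it, for example to compare the predictions of every retained checkpoint. A minimal sketch, assuming the module-level model and a saver built as above:

ckpt = tf.train.get_checkpoint_state(r"E:\tf_project\练习\model_save_dir")
for path in ckpt.all_model_checkpoint_paths:
    with tf.Session() as sess:
        saver.restore(sess, path)  # load this checkpoint's variable values
        print(path, sess.run(inference([[80., 25.]])))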

3. Load via the graph structure, using tf.train.import_meta_graph

with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    init = tf.global_variables_initializer()
    sess.run(init)
    training_steps = 10000

    # import_meta_graph rebuilds the saved graph from the .meta file and returns
    # a Saver for it; latest_checkpoint resolves the newest prefix in the directory
    saver = tf.train.import_meta_graph(r"E:\tf_project\练习\model_save_dir\my-model-20000.meta")
    saver.restore(sess, tf.train.latest_checkpoint(r"E:\tf_project\练习\model_save_dir"))
    print(tf.train.latest_checkpoint(r"E:\tf_project\练习\model_save_dir"))

    for step in range(training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\练习\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\练习\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()

output:

E:\tf_project\练习\model_save_dir\my-model-9000-20000
loss 7608772.0
loss 5352849.5
loss 5350043.5
loss 5347919.0
loss 5346300.5
.
.
.
loss 5226120.5
loss 5226008.0
loss 5225895.5
loss 5225782.0
[[320.33838]]
[[269.12772]]
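
The real strength of import_meta_graph is that it needs no model-building code at all: the graph itself comes from the .meta file. A minimal sketch (the tensor names "weights:0" and "bias:0" follow from the name= arguments used when the variables were created above):

import tensorflow as tf

tf.reset_default_graph()  # start from an empty graph; inference()/loss() are not needed
saver = tf.train.import_meta_graph(r"E:\tf_project\练习\model_save_dir\my-model-20000.meta")

with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(r"E:\tf_project\练习\model_save_dir"))
    graph = tf.get_default_graph()
    # Look up the restored variables by the names they were created with
    print(sess.run(graph.get_tensor_by_name("weights:0")))
    print(sess.run(graph.get_tensor_by_name("bias:0")))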

4. The general-purpose way: saver.restore

Here you can specify exactly which saved model to load.

with tf.Session() as sess:
    # Point directly at one specific checkpoint prefix (step 9000 here)
    CHECKPOINT_PATH = r"E:\tf_project\练习\model_save_dir\my-model-9000"
    saver = tf.train.Saver()  # W and b already exist at module level, so the Saver can be built here
    saver.restore(sess, CHECKPOINT_PATH)
    print(CHECKPOINT_PATH)
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    initial_step = 0
    training_steps = 10000

    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        if step % 1000 == 0:
            # Saving with this prefix yields names like my-model-9000-1000, which is
            # where the my-model-9000-20000 files seen in earlier outputs come from
            saver.save(sess, CHECKPOINT_PATH, global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, CHECKPOINT_PATH, global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()

output:

E:\tf_project\练习\model_save_dir\my-model-9000
loss 5236970.0
loss 5236958.5
loss 5236947.5
loss 5236935.5
.
.
.
loss 5225708.5
loss 5225696.5
loss 5225685.5
[[320.33884]]
[[269.1281]]
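
tf.train.Saver can also be built over an explicit variable list, which lets you restore only a subset of the saved variables. A minimal sketch, reusing the module-level W (checkpoint path illustrative):

# Restore only the weights; the bias keeps its freshly initialized value
partial_saver = tf.train.Saver({"weights": W})
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize everything first
    partial_saver.restore(sess, r"E:\tf_project\练习\model_save_dir\my-model-20000")  # then overwrite W
    print(sess.run(W))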

 
