Tensorflow保存和恢复模型

方法

saver = tf.train.Saver()
获得一个文件句柄,将训练中的某一个快照状态保存到文件中去

saver.save(sess, os.path.join(model_dir, ‘ckp-%05d’%(i+1)))
将训练好的模型保存到文件中
sess:session会话
参数2:模型保存路径

saver.restore(sess, model_path)
从文件中恢复模型
sess:session会话
model_path:需要

演示

with tf.Session() as sess:
    sess.run( init ) # 注意: 这一步必须要有!!

    # 打开一个writer,向writer中写数据
    train_writer = tf.summary.FileWriter(train_log_dir, sess.graph) # 参数2:显示计算图
    test_writer = tf.summary.FileWriter(test_log_dir)
    
    fixed_test_batch_data, fixed_test_batch_labels = test_data.next_batch(batch_size)

    if os.path.exists(model_path + '.index'):
        saver.restore(sess, model_path)
        print('model restored from %s' % model_path)
    else:
        print('model %s dose not exist' % model_path)

    # 开始训练
    for i in range( train_steps ):
        # 得到batch
        batch_data, batch_labels = train_data.next_batch( batch_size )
        
        eval_ops = [loss, accuracy, train_op]
        should_output_summary = ((i+1) % output_summary_every_steps == 0)

        if should_output_summary:
            eval_ops.append(merged_summary)

        # 获得 损失值, 准确率
        eval_val_results = sess.run( eval_ops, feed_dict={x:batch_data, y:batch_labels} )
        loss_val, acc_val = eval_val_results[0:2]

        if should_output_summary:
            train_summary_str = eval_val_results[-1]
            train_writer.add_summary(train_summary_str, i+1)
            test_summary_str = sess.run([merged_summary_test],
                                        feed_dict = {x: fixed_test_batch_data,y: fixed_test_batch_labels} )[0]
            test_writer.add_summary(test_summary_str, i+1)

        # 每 500 次 输出一条信息
        if ( i+1 ) % 500 == 0:
            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f' % ( i+1, loss_val, acc_val ))
        # 每 5000 次 进行一次 测试
        if ( i+1 ) % 5000 == 0:
            # 获取数据集,但不随机
            test_data = CifarData( test_filename, False )
            all_test_acc_val = []
            for j in range( test_steps ):
                test_batch_data, test_batch_labels = test_data.next_batch( batch_size )
                test_acc_val = sess.run( [accuracy], feed_dict={ x:test_batch_data, y:test_batch_labels } )
                all_test_acc_val.append( test_acc_val )
            test_acc = np.mean( all_test_acc_val )

            print('[Test ] Step: %d, acc: %4.5f' % ( (i+1), test_acc ))

        if (i+1) % output_model_every_steps == 0:
            # saver 机制,保存最近的5个模型
            saver.save(sess, os.path.join(model_dir, 'ckp-%05d'%(i+1)))
            print('model saved to ckp-%05d' % (i+1))

你可能感兴趣的:(tensorflow,Tensorflow学习笔记)