实现断点续训
输入真实图片,输出预测结果
实现断点续训,在 mnist_backward.py 里加入三行代码即可:
with tf.Session() as sess:
    # Initialize every variable first; a successful restore below then
    # overwrites these fresh values with the saved ones.
    sess.run(tf.global_variables_initializer())
    # Resume training from the latest checkpoint ------------------------
    # get_checkpoint_state looks in MODEL_SAVE_PATH for a checkpoint state
    # file; the `if` skips the restore when nothing has been saved yet, so
    # the first run simply trains from the initialized variables.
    ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    # --------------------------------------------------------------------
    # NOTE(review): `i` restarts at 0 on every run, so each run trains STEPS
    # more batches; `step` (from global_step) is what reports the running
    # total, assuming `saver` also saves/restores global_step — confirm.
    for i in range(STEPS):
        xs, ys = mnist.train.next_batch(BATCH_SIZE)
        _, loss_v, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
        if i % 1000 == 0:
            print('After %d training steps, loss on training batch is %g.' % (step, loss_v))
            # Save a checkpoint (file name suffixed with the global step) so a
            # later run can resume from here.
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
(1) tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)
该函数检查断点文件夹中是否存在有效的断点状态文件(checkpoint 文件):若存在,返回一个 CheckpointState 对象(其 model_checkpoint_path 属性即最新保存的模型路径);否则返回 None。
参数说明:
checkpoint_dir:表示存储断点文件的目录
latest_filename=None:断点状态文件的可选名称,为 None 时使用默认文件名“checkpoint”
(2) saver.restore(sess, ckpt.model_checkpoint_path)
该函数将断点文件中保存的变量值(如 w 和 b)恢复到当前会话的对应变量中。
参数说明:
sess:表示当前会话,之前保存的结果将被加载入这个会话
ckpt.model_checkpoint_path:最新保存的模型的路径。该路径是 get_checkpoint_state 读取 checkpoint 文件后得到的,因此无需手动指定模型的名字,程序会自动找到最新保存的模型。
输入真实图片,输出预测结果:
mnist_forward.py、mnist_backward.py 和 mnist_test.py 保持不变,新增一个 mnist_app.py。
模型要求输入为黑(0)底白(255)字,而实际输入的图片是白底黑字,因此需要把每个像素值改为 255 减去原值,得到反色图像。
# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backward
def restore_model(testPicArr):
    """Rebuild the inference network, load the newest checkpoint, and predict.

    Everything is built inside a brand-new ``tf.Graph`` on every call, so
    repeated calls never add duplicate variables to a shared default graph.

    Args:
        testPicArr: preprocessed input fed to the placeholder of shape
            ``[None, mnist_forward.INPUT_NODE]``.

    Returns:
        The result of ``tf.argmax(y, 1)`` evaluated on the input, or ``-1``
        when no checkpoint is found under ``mnist_backward.MODEL_SAVE_PATH``.
    """
    with tf.Graph().as_default():
        images = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        logits = mnist_forward.forward(images, None)
        prediction = tf.argmax(logits, 1)

        # Restore the moving-average (shadow) values saved during training
        # into the model variables, instead of the raw variable values.
        averager = tf.train.ExponentialMovingAverage(mnist_backward.EMA_DECAY)
        loader = tf.train.Saver(averager.variables_to_restore())

        with tf.Session() as sess:
            # Locate the most recently saved model through the checkpoint file.
            state = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
            if not (state and state.model_checkpoint_path):
                print('No checkpoint file found.')
                return -1
            loader.restore(sess, state.model_checkpoint_path)
            return sess.run(prediction, feed_dict={images: testPicArr})
# 输入图片预处理函数
def pre_pic(picName):
img = Image.open(picName)
# 用消除锯齿的方式 resize
reIm = img.resize((28, 28), Image.ANTIALIAS)
# 转变为灰度图
im_arr = np.array(reIm.convert('L'))
# 设定合理的阈值,对图片做二值化处理(这样以滤掉噪声,另外调试中可适当调节阈值)
threshold = 50
for i in range(28):
for j in range(28):
im_arr[i][j] = 255 - im_arr[i][j]
if (im_arr[i][j] < threshold):
im_arr[i][j] = 0
else:
im_arr[i][j] = 255
nm_arr = im_arr.reshape([1, 784])
nm_arr = nm_arr.astype(np.float32)
im_ready = np.multiply(nm_arr, 1.0/255.0)
return im_ready
def application():
    """Interactively predict the digit shown in a number of pictures.

    Asks how many pictures to test. For each one it asks for the file path,
    preprocesses the picture with ``pre_pic``, and prints the value returned
    by ``restore_model``.
    """
    count = int(input('input the number of test pictures:'))
    for _ in range(count):
        path = input('input the path of test picture:')
        # Preprocess the handwritten picture, then feed the ready-made input
        # to the restored network to obtain its prediction.
        prediction = restore_model(pre_pic(path))
        print('The prediction number is: ', prediction)
# Start the interactive demo only when run as a script, not when imported.
if __name__ == '__main__':
    application()
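如果去掉 with tf.Graph().as_default() as g:,那么每次调用 restore_model 都会在同一个默认图中重复创建网络变量。第一张图片可以正常识别,但第二次调用时新变量被自动命名为 Variable_4、Variable_5 等。Saver 会按这些新名字去查找滑动平均影子变量(如 Variable_7/ExponentialMovingAverage),而 checkpoint 中并不存在这些名字,
因此识别第二张图片时会报错: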
input the number of test pictures:10
input the path of test picture:pic/0.png
2018-07-19 10:19:56.079652: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2018-07-19 10:19:56.317172: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:892] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2018-07-19 10:19:56.317466: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Found device 0 with properties:
name: GeForce 940MX major: 5 minor: 0 memoryClockRate(GHz): 1.2415
pciBusID: 0000:01:00.0
totalMemory: 1.96GiB freeMemory: 1.94GiB
2018-07-19 10:19:56.317483: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce 940MX, pci bus id: 0000:01:00.0, compute capability: 5.0)
The prediction number is: [0]
input the path of test picture:pic/1.png
2018-07-19 10:20:12.682054: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce 940MX, pci bus id: 0000:01:00.0, compute capability: 5.0)
2018-07-19 10:20:12.689655: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_7/ExponentialMovingAverage not found in checkpoint
2018-07-19 10:20:12.690771: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_4/ExponentialMovingAverage not found in checkpoint
2018-07-19 10:20:12.691076: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_5/ExponentialMovingAverage not found in checkpoint
2018-07-19 10:20:12.691219: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_6/ExponentialMovingAverage not found in checkpoint
Traceback (most recent call last):
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
return fn(*args)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
status, run_metadata)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_7/ExponentialMovingAverage not found in checkpoint
[[Node: save_1/RestoreV2_7 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_7/tensor_names, save_1/RestoreV2_7/shape_and_slices)]]
[[Node: save_1/RestoreV2_6/_3 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_22_save_1/RestoreV2_6", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "mnist_application.py", line 58, in
application()
File "mnist_application.py", line 53, in application
preValue = restore_model(testPicArr)
File "mnist_application.py", line 21, in restore_model
saver.restore(sess, ckpt.model_checkpoint_path)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1666, in restore
{self.saver_def.filename_tensor_name: save_path})
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 889, in run
run_metadata_ptr)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1120, in _run
feed_dict_tensor, options, run_metadata)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1317, in _do_run
options, run_metadata)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1336, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_7/ExponentialMovingAverage not found in checkpoint
[[Node: save_1/RestoreV2_7 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_7/tensor_names, save_1/RestoreV2_7/shape_and_slices)]]
[[Node: save_1/RestoreV2_6/_3 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_22_save_1/RestoreV2_6", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
Caused by op 'save_1/RestoreV2_7', defined at:
File "mnist_application.py", line 58, in
application()
File "mnist_application.py", line 53, in application
preValue = restore_model(testPicArr)
File "mnist_application.py", line 14, in restore_model
saver = tf.train.Saver(ema_restore)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1218, in __init__
self.build()
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1227, in build
self._build(self._filename, build_save=True, build_restore=True)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1263, in _build
build_save=build_save, build_restore=build_restore)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 751, in _build_internal
restore_sequentially, reshape)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 427, in _AddRestoreOps
tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 267, in restore_op
[spec.tensor.dtype])[0])
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1021, in restore_v2
shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
op_def=op_def)
File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
NotFoundError (see above for traceback): Key Variable_7/ExponentialMovingAverage not found in checkpoint
[[Node: save_1/RestoreV2_7 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_7/tensor_names, save_1/RestoreV2_7/shape_and_slices)]]
[[Node: save_1/RestoreV2_6/_3 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_22_save_1/RestoreV2_6", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]