保存
1、定义变量
2、使用saver.save()方法保存
import os

import numpy as np
import tensorflow as tf

# Two variables with fixed values. Their names ('w', 'b') are the keys under
# which the values are stored in the checkpoint, and the keys used to find
# them again when restoring.
W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='w')
b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b')

# tf.initialize_all_variables() is deprecated; global_variables_initializer()
# is its replacement.
init = tf.global_variables_initializer()
saver = tf.train.Saver()

# Saver.save() does not create missing parent directories.
if not os.path.isdir('save'):
    os.makedirs('save')

with tf.Session() as sess:
    sess.run(init)
    # save() returns the path the checkpoint was written to.
    save_path = saver.save(sess, "save/model.ckpt")
    print('Model saved to: %s' % save_path)
载入
1、定义变量
2、使用saver.restore()方法载入
import tensorflow as tf
import numpy as np

# Re-declare the variables under the same names ('w', 'b') that were used
# when saving. The random initial values do not matter: restore() overwrites
# them with the checkpoint values, which are matched by name.
W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='w')
b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b')

saver = tf.train.Saver()
checkpoint_path = "save/model.ckpt"
with tf.Session() as session:
    # No initializer is run: restore() assigns every variable the saver tracks.
    saver.restore(session, checkpoint_path)
这种方法的不便之处在于:在使用模型的时候,必须把模型的结构重新定义一遍,然后再载入对应名字的变量的值。但很多时候我们更希望能够读取一个文件后就直接使用模型,而不必再把模型重新定义一遍。所以就需要使用另一种方法。
不需重新定义网络结构的方法
tf.train.import_meta_graph() 方法可以从 .meta 文件中将保存的 graph 的所有节点加载到当前的 default graph 中,并返回一个 saver。也就是说,我们在保存的时候,除了保存变量的值,其实还将对应 graph 中的各种节点保存了下来,所以模型的结构也同样被保存了下来。
保存
import os

### Define the model
# NOTE: illustrative fragment -- in_dim, h1_dim, out_dim and train_op are
# assumed to be defined elsewhere.
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')
w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
# Use the tensors defined above directly: this is script code, not a class
# method, so there is no `self`.
hidden1 = tf.nn.relu(tf.matmul(input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
### Define the prediction target
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# Create the saver. Without arguments it saves all global variables; pass a
# var_list to restrict what is saved.
saver = tf.train.Saver()
# Put y into a named collection so it can be fetched for prediction after the
# graph is re-imported with tf.train.import_meta_graph().
tf.add_to_collection('pred_network', y)
# Saver.save() fails if the parent directory does not exist. This is also
# the directory the loading code reads from.
save_dir = 'my-save-dir'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
with tf.Session() as sess:
    # Variables must be initialized before the first training step.
    sess.run(tf.global_variables_initializer())
    for step in range(1000000):
        # NOTE: train_op depends on the placeholders above, so real code must
        # pass a feed_dict with a training batch here.
        sess.run(train_op)
        if step % 1000 == 0:
            # Save a checkpoint. By default a meta graph is exported with it as
            # 'my-model-{global_step}.meta'. Saver keeps only the 5 most recent
            # checkpoints (max_to_keep=5).
            saver.save(sess, os.path.join(save_dir, 'my-model'), global_step=step)
载入
with tf.Session() as sess:
    # Load the newest checkpoint instead of a hard-coded step. Saver keeps
    # only the last 5 checkpoints by default, so an early one such as
    # 'my-model-10000' has already been deleted by the end of training.
    ckpt = tf.train.latest_checkpoint('my-save-dir')
    if ckpt is None:
        raise ValueError('No checkpoint found in my-save-dir')
    # Rebuild the graph from the meta file written next to that checkpoint,
    # then load the variable values into it.
    new_saver = tf.train.import_meta_graph(ckpt + '.meta')
    new_saver.restore(sess, ckpt)
    # tf.get_collection() returns a list; only its first element is needed.
    y = tf.get_collection('pred_network')[0]
    graph = tf.get_default_graph()
    # y depends on placeholders, so sess.run(y) must feed them with the samples
    # to predict and the matching parameters. Look them up by operation name.
    input_x = graph.get_operation_by_name('input_x').outputs[0]
    keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]
    # TODO: replace with the real samples, shaped (batch, in_dim) to match the
    # 'input_x' placeholder.
    x_batch = ...
    # Predict with y; keep_prob=1.0 disables dropout at inference time.
    pred = sess.run(y, feed_dict={input_x: x_batch, keep_prob: 1.0})
上述内容转载自:http://blog.csdn.net/thriving_fcl/article/details/71423039
下面源码仅供参考
import tensorflow as tf
import cv2
import alexnet as AN
import genarate_trainning_data as gtd
import create_tfRecord as tfRec
import numpy as np
import os
import time
from skimage import io as sio
from skimage.segmentation import slic
from skimage.segmentation import felzenszwalb as felseg#图分割函数
FLAGS = tf.app.flags.FLAGS
# Command-line string flags as (name, default value, help text).
for _flag_name, _flag_default, _flag_help in (
        ('image_dir', './SED2/', """Directory where to detect saliency """),
        ('model_dir', './model/', """Directory where to check point file """),
        ('mean_image', './SED2/tfRecord_file/', """Directory where to demean """),
):
    tf.app.flags.DEFINE_string(_flag_name, _flag_default, _flag_help)
# Don't leak the loop variables as module globals.
del _flag_name, _flag_default, _flag_help

# float16 placeholder of shape [batch, IMAGE_HEIGHT, IMAGE_WIDTH, 3].
# NOTE(review): nothing visible in this file reads `x`; influence() feeds the
# 'input_x' tensor of the restored graph instead. Confirm before removing.
x = tf.placeholder(dtype=tf.float16,
                   shape=[None, tfRec.IMAGE_HEIGHT, tfRec.IMAGE_WIDTH, 3])
def caculate_multi_level(images, num_levels=15):
    """Build the multi-scale superpixel inputs for each image at each segmentation level.

    The name (including the 'caculate' spelling) is kept for existing callers.

    Args:
        images: image file names, relative to FLAGS.image_dir.
        num_levels: number of Felzenszwalb parameter sets from seg_para.npy
            to use (levels 0 .. num_levels-1).

    Returns:
        A list with one entry per image. Each entry is a list with one entry
        per level, holding the non-None results of
        tfRec.merge_multiscale_image for that level's superpixels, in
        ascending segment-id order.
    """
    # seg_para.npy holds a pickled dict (read back via .item()). NumPy >= 1.16.3
    # refuses to load object arrays unless allow_pickle=True. Only load
    # trusted files this way.
    fparams = np.load('./seg_para.npy', allow_pickle=True).item()
    mean_image = np.load(FLAGS.mean_image + 'mean_image.npy')
    # cv2.resize takes dsize as (width, height).
    dsize = (tfRec.IMAGE_WIDTH, tfRec.IMAGE_HEIGHT)
    results = []
    for curr_img in images:
        img = sio.imread(FLAGS.image_dir + curr_img)
        img = cv2.resize(img, dsize)
        levels = []
        for g in range(num_levels):
            # np.float / np.int were plain aliases of float / int and were
            # removed in NumPy 1.24.
            f_seg = felseg(img, sigma=fparams['sigma'][g],
                           scale=float(fparams['scale'][g]),
                           min_size=int(fparams['min_size'][g]))
            patches = []
            for segment in np.unique(f_seg):
                # Same argument list as used in influence().
                multi_scale_image = tfRec.merge_multiscale_image(
                    img, mean_image, f_seg, segment)
                if multi_scale_image is not None:
                    patches.append(multi_scale_image)
            levels.append(patches)
        results.append(levels)
    return results
def _restore_model(sess):
    """Restore the newest checkpoint in FLAGS.model_dir into ``sess``.

    Returns:
        (predict, input_x, keep_pro): the first tensor in the 'predict'
        collection, plus the output tensors of the 'input_x' and 'keep_pro'
        operations of the restored graph.

    Raises:
        ValueError: if FLAGS.model_dir holds no checkpoint or no meta graph.
    """
    ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
    if ckpt is None:
        raise ValueError('No checkpoint found in %s' % FLAGS.model_dir)
    # Prefer the meta graph saved next to the checkpoint being restored. Fall
    # back to a *meta file from the directory, sorted because os.listdir order
    # is unspecified.
    meta_path = ckpt + '.meta'
    if not os.path.exists(meta_path):
        metas = sorted(fn for fn in os.listdir(FLAGS.model_dir) if fn.endswith('meta'))
        if not metas:
            raise ValueError('No meta graph found in %s' % FLAGS.model_dir)
        meta_path = os.path.join(FLAGS.model_dir, metas[0])
    saver = tf.train.import_meta_graph(meta_path)
    saver.restore(sess, ckpt)
    graph = tf.get_default_graph()
    predict = tf.get_collection('predict')[0]
    input_x = graph.get_operation_by_name('input_x').outputs[0]
    keep_pro = graph.get_operation_by_name('keep_pro').outputs[0]
    return predict, input_x, keep_pro


def _predict_level(sess, tensors, img, mean_image, fparams, g):
    """Segment ``img`` with the level-``g`` Felzenszwalb parameters and classify each superpixel.

    Returns:
        A uint8 mask with the segmentation's shape. Pixels of superpixels with
        a non-zero predicted label are 255. All other pixels are 0, including
        superpixels for which merge_multiscale_image returned None.
    """
    predict, input_x, keep_pro = tensors
    # np.float / np.int were plain aliases of float / int and were removed
    # in NumPy 1.24.
    f_seg = felseg(img, sigma=fparams['sigma'][g],
                   scale=float(fparams['scale'][g]),
                   min_size=int(fparams['min_size'][g]))
    # Build the mask from an untouched copy of the labels. Relabeling f_seg
    # in place with 0/255 let segment ids 0 and 255, and the skipped segments
    # folded into 0, collide with the output values.
    labels = f_seg.copy()
    new_sp = []
    sp_batch = []
    for segment in np.unique(f_seg):
        multi_scale_image = tfRec.merge_multiscale_image(img, mean_image, f_seg, segment)
        if multi_scale_image is None:
            # Unusable superpixel: fold it into label 0 in the working array
            # passed to later merge_multiscale_image calls.
            f_seg[f_seg == segment] = 0
            continue
        new_sp.append(segment)
        sp_batch.append(multi_scale_image)
    mask = np.zeros(labels.shape, dtype=np.uint8)
    if not sp_batch:
        # Nothing to classify; avoid feeding an empty batch to the network.
        return mask
    sp_predict = sess.run(predict, feed_dict={input_x: sp_batch, keep_pro: 1.0})
    # Assumes predict yields one row of class scores per fed superpixel, i.e.
    # shape (len(sp_batch), num_classes) -- TODO confirm. The label is
    # therefore the argmax over axis 1. An argmax over axis 0 returns one
    # value per class instead.
    sp_label = np.argmax(sp_predict, axis=1)
    salient = np.asarray(new_sp)[sp_label != 0]
    mask[np.isin(labels, salient)] = 255
    return mask


def influence(num_levels=15):
    """Run saliency detection on the images listed for FLAGS.image_dir.

    For every image this writes into <image_dir>/out_result/:
      * '<name>-0.jpg': the resized ground-truth map;
      * '<name>-<g+1>.jpg' for g in range(num_levels): the 0/255 mask
        predicted with Felzenszwalb parameter set g.

    Args:
        num_levels: number of parameter sets from seg_para.npy to evaluate.
    """
    out_dir = os.path.join(FLAGS.image_dir, 'out_result')
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    gt, images = gtd.dirtomdfbatchmsra(FLAGS.image_dir)
    # seg_para.npy holds a pickled dict (read back via .item()). NumPy >= 1.16.3
    # needs allow_pickle=True for that. Only load trusted files this way.
    fparams = np.load('./seg_para.npy', allow_pickle=True).item()
    mean_image = np.load(FLAGS.mean_image + 'mean_image.npy')
    # cv2.resize takes dsize as (width, height).
    dsize = (tfRec.IMAGE_WIDTH, tfRec.IMAGE_HEIGHT)
    with tf.Session() as sess:
        tensors = _restore_model(sess)
        for curr_img, curr_gt in zip(images, gt):
            # splitext rather than [0:-4] so any extension length works.
            stem = os.path.splitext(curr_img)[0]
            img = cv2.resize(sio.imread(FLAGS.image_dir + curr_img), dsize)
            gt_map = cv2.resize(sio.imread(FLAGS.image_dir + curr_gt), dsize)
            cv2.imwrite(os.path.join(out_dir, stem + '-0.jpg'), gt_map)
            for g in range(num_levels):
                start_time = time.time()
                mask = _predict_level(sess, tensors, img, mean_image, fparams, g)
                duration = time.time() - start_time
                out_name = '%s-%d.jpg' % (stem, g + 1)
                cv2.imwrite(os.path.join(out_dir, out_name), mask)
                print(out_name + ' has created...(%.3f sec)' % duration)
def main(_):
    """Entry point invoked by tf.app.run(); the argv argument is ignored."""
    return influence()
if __name__ == '__main__':
    # Hand control to TensorFlow's runner, which then calls main(argv).
    tf.app.run(main=main)