tensorflow实战google深度学习框架阅读笔记——保存,读取model(ckpt文件)

 

最近在阅读《tensorflow实战google深度学习框架》,对里面讲到的内容,重点部分做下摘抄和笔记,以备后面查阅。部分内容为本人个人理解,如果错误,请指正,如果侵权,请联系删除,谢谢。转载请注明出处,谢谢。

 

将模型保存为ckpt文件

 

 

    首先,创建一个saver对象:saver=tf.train.Saver(max_to_keep = 5) 注意,这句话要写在创建graph的代码中,在图创建完成并且初始化variable后,再调用。max_to_keep代表需要保存的模型的个数,默认为5,如果只需要保存最新的模型,设置为1即可

    然后,saver.save(sess,'ckpt/mnist.ckpt',global_step=step,write_meta_graph=False)  ,其中,第一个参数为sess,第二个参数设定保存的路径和名字,这里可以用os.path.join(opts.save_path,"model.ckpt")来组合,第一个为路径,第二个为名字。第三个参数将训练的次数作为后缀加入到模型名字中去(改参数可以不加),第四个参数为是否保存.meta,meta中保存的是模型的图,不需要每次都保存,所以可以设置为false,默认为true,这里可以这么写:

saver.save(sess, 'model/model.ckpt', global_step=step, write_meta_graph=False)

if not os.path.exists(' model/model.meta'):

   saver.export_meta_graph(metagraph_filename)

 

从ckpt文件中读取参数,恢复模型

 

有两种方法:

方法一:不恢复模型,直接从ckpt文件中读取参数:

 

tensorflow实战google深度学习框架阅读笔记——保存,读取model(ckpt文件)_第1张图片

方法二:恢复模型,然后调用sess.run获取参数:

tensorflow实战google深度学习框架阅读笔记——保存,读取model(ckpt文件)_第2张图片

 

注意:这里面,tf.train.import_meta_graph是从meta文件中读取图,如果图中存在自定义的op,是行不通的,所以【4】这个操作是加载自定义op

另外,这里参数的名字可以通过方式一获取,然后这里的要用  参数名:0 的方式获取参数。

 

利用tensorboard可视化模型图

 

tf.reset_default_graph()
with tf.Session() as sess:
  with open('F:\\jupyter\\data\\xxxx\\xxxx.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def)

  LOGDIR='F:\\jupyter\\data\\xxx'
  train_writer = tf.summary.FileWriter(LOGDIR)
  train_writer.add_graph(sess.graph)
  train_writer.close()

运行这段代码后,会将图结构以日志文件的形式保存到给定的路径下。 

然后在终端启动tensorboard,输入log文件的地址,tensorboard --logdir=XXXX , 就可以在tensorboard中查看模型的图结构了,

tensorboard的默认端口是6006,浏览器127.0.0.1:6006即可访问。

上面的代码是将保存好的pb文件,先解析pb文件,然后利用tf.summary.FileWriter保存为log,如果不是pb文件,而是在程序中,直接将图保存即可:

  LOGDIR='F:\\jupyter\\data\\xxx'
  train_writer = tf.summary.FileWriter(LOGDIR)
  train_writer.add_graph(sess.graph)

 

 

 

 

你可能感兴趣的:(tensorflow实战google深度学习框架阅读笔记——保存,读取model(ckpt文件))