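》After training finishes, we get four checkpoint-related files【3】【4】: .data, .index, .meta and checkpoint (with the checkpoint prefix used below, for example checkpoint-24336.data-00000-of-00001, checkpoint-24336.index, checkpoint-24336.meta and checkpoint).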
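Of these, .meta holds the graph structure, i.e. the structure of the neural network. The graph does not change during training, so it only needs to be saved once. Implementation: saver = tf.train.Saver(); the first saver.save(sess, 'model-name') writes the .meta file (write_meta_graph defaults to True), and later calls can use saver.save(sess, 'model-name', write_meta_graph=False) to skip rewriting it.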
.data stores the numerical values of the model's variables: weights, biases and so on.
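.index mainly stores the names of the tensors held in .data (together with where each one is located in .data).
In other words, .index and .data together form key-value pairs.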
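checkpoint is a small text file that records the names of the models saved at intermediate points during training (by default tf.train.Saver keeps only the five most recent, max_to_keep=5). Its first line holds the name of the most recently saved model.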
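Because the checkpoint file is plain text (a text-format CheckpointState protocol buffer), the latest checkpoint prefix can be read without TensorFlow; inside TensorFlow, tf.train.latest_checkpoint(checkpoint_dir) performs this lookup. The sketch below is a minimal stand-in, assuming the usual layout of one key: "value" pair per line. The file contents are hypothetical (they reuse the checkpoint-24336 name from the script below), and the helper parse_checkpoint_state is illustrative only. Lines without a quoted value, such as the timestamp fields newer TensorFlow versions write, are skipped.

```python
import re

# Hypothetical contents of a `checkpoint` file; the real one is written by tf.train.Saver.
SAMPLE_CHECKPOINT_FILE = '''model_checkpoint_path: "checkpoint-24336"
all_model_checkpoint_paths: "checkpoint-20000"
all_model_checkpoint_paths: "checkpoint-24336"
'''

# Matches lines of the form:  key: "value"
_LINE = re.compile(r'^\s*(\w+)\s*:\s*"(.*)"\s*$')


def parse_checkpoint_state(text):
    """Return (latest_path, all_paths) parsed from the text of a `checkpoint` file."""
    latest = None
    all_paths = []
    for line in text.splitlines():
        match = _LINE.match(line)
        if not match:
            continue  # blank lines or unquoted fields (e.g. timestamps) are ignored
        key, value = match.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


latest, all_paths = parse_checkpoint_state(SAMPLE_CHECKPOINT_FILE)
print(latest)     # checkpoint-24336
print(all_paths)  # ['checkpoint-20000', 'checkpoint-24336']
```

The returned latest value is the prefix that would be passed to saver.restore, i.e. what the script below supplies as checkpoint_path.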
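》The feature maps are extracted mainly from the .data, .index and .meta files. Following【2】, the procedure is as follows. First, print every op name in the graph to find the layer whose output you want: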
# graph = tf.get_default_graph() after tf.train.import_meta_graph(...), as in the script below
for op in graph.get_operations():
    print(op.name)
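graph.get_operations() yields operations, while get_tensor_by_name expects tensor names: the op name followed by a colon and an output index. A minimal sketch of narrowing the printed list down to candidate feature-map tensors. It assumes, as with the vnet/vnet/encoder/level_2/conv_2/result op used in the script below, that each convolution block ends in an op named .../result; the helper candidate_tensor_names is hypothetical, and apart from that op and images_placeholder the op names are invented for illustration.

```python
def candidate_tensor_names(op_names, scope, suffix="/result"):
    """Turn op names under `scope` that end in `suffix` into tensor names (first output, ':0')."""
    return [
        name + ":0"
        for name in op_names
        if name.startswith(scope) and name.endswith(suffix)
    ]


# Hypothetical op names, in the style printed by `for op in graph.get_operations()`.
op_names = [
    "images_placeholder",
    "vnet/vnet/encoder/level_1/conv_1/result",
    "vnet/vnet/encoder/level_2/conv_1/result",
    "vnet/vnet/encoder/level_2/conv_2/result",
    "vnet/vnet/encoder/level_2/conv_2/Relu",
    "vnet/vnet/decoder/level_1/conv_1/result",
]

for tensor_name in candidate_tensor_names(op_names, "vnet/vnet/encoder/level_2"):
    print(tensor_name)
```

With a real graph, the list would come from [op.name for op in graph.get_operations()], and any printed name can be passed straight to graph.get_tensor_by_name.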
—————————————————————————————————————————————————
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import datetime
import numpy as np
import SimpleITK as sitk
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.app.flags)

# Raw strings, so the Windows backslashes are not read as escape sequences
image_path = r"D:\KITS\Code\vnet-tensorflow-master\extract_feature_map\TestPatch\image.nii.gz"
save_path = r"D:\KITS\Code\vnet-tensorflow-master\extract_feature_map\data_preprocess_vnet"

# select gpu devices
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # e.g. "0,1,2", "0,2"

# tensorflow app flags
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_path',
                           r'D:\KITS\same_space\data_preprocess\modelVNet2\tmp\ckpt\checkpoint-24336.meta',
                           """Path to the saved .meta graph file""")
tf.app.flags.DEFINE_string('checkpoint_path',
                           r'D:\KITS\same_space\data_preprocess\modelVNet2\tmp\ckpt\checkpoint-24336',
                           """Checkpoint prefix (the .index/.data files without their suffixes)""")


def truncateImage(image_np, low_value=-79, high_value=304):
    """Clip intensities to [low_value, high_value], then normalize.

    101 and 76.9 must match the preprocessing used during training
    (presumably the mean and standard deviation of the training data).
    """
    image_np = image_np.clip(min=low_value, max=high_value)
    image_np = image_np - 101
    image_np = np.true_divide(image_np, 76.9)
    return image_np


def evaluate():
    """Restore the VNet graph and weights, then save one intermediate feature map."""
    # restore model graph
    tf.reset_default_graph()
    # Load the graph from the .meta file; this returns a tf.train.Saver
    imported_meta = tf.train.import_meta_graph(FLAGS.model_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    graph = tf.get_default_graph()

    # set input
    input_ = graph.get_tensor_by_name('images_placeholder:0')
    print('input shape:', input_.shape)
    # set output: the layer whose feature map we want
    output_ = graph.get_tensor_by_name('vnet/vnet/encoder/level_2/conv_2/result:0')

    with tf.Session(config=config) as sess:
        print("{}: Start evaluation...".format(datetime.datetime.now()))
        # Load the weights from the checkpoint-24336 (.index, .data) files
        imported_meta.restore(sess, FLAGS.checkpoint_path)
        print("{}: Restore checkpoint success".format(datetime.datetime.now()))

        image = sitk.ReadImage(image_path)
        # Note: GetArrayFromImage returns the axes in (z, y, x) order
        image_np = sitk.GetArrayFromImage(image).astype(np.float32)
        image_np = truncateImage(image_np)
        # Add batch and channel axes: [D, H, W] -> [1, D, H, W, 1];
        # this should be compatible with the placeholder shape printed above
        patch = np.expand_dims(image_np, axis=0)
        patch = np.expand_dims(patch, axis=-1)

        patch_pd = sess.run(output_, feed_dict={input_: patch})
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        np.save(os.path.join(save_path, "encoder_l2_conv_2.npy"), patch_pd)
        print("Finished inference, feature map shape:", patch_pd.shape)

    # for op in graph.get_operations():
    #     print(op.name)


def main():
    evaluate()


if __name__ == '__main__':
    main()
—————————————————————————————————————————————————
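The preprocessing in the script clips intensities to [-79, 304], subtracts 101 and divides by 76.9. Since 101 is not the centre of that window ((-79 + 304) / 2 = 112.5), these two constants are presumably the mean and standard deviation of the training intensities rather than a plain rescaling; either way, the image fed in for feature extraction must be preprocessed exactly as it was during training. A stdlib-only sketch of the same arithmetic for single values (the name normalize_voxel is illustrative), showing the range the network actually sees, roughly -2.34 to 2.64:

```python
LOW, HIGH = -79.0, 304.0   # clipping window used by the script
MEAN, STD = 101.0, 76.9    # shift and scale used by the script


def normalize_voxel(value):
    """Clip one intensity value to [LOW, HIGH], then shift by MEAN and scale by STD."""
    clipped = min(max(value, LOW), HIGH)
    return (clipped - MEAN) / STD


for hu in (-1000.0, -79.0, 101.0, 304.0, 1500.0):
    print(hu, "->", round(normalize_voxel(hu), 3))
```

In the script, the same three steps are applied element-wise to the whole NumPy volume via clip, subtraction and np.true_divide.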
References:
【1】https://blog.csdn.net/qq_41185868/article/details/82903223
【2】https://murphypei.github.io/blog/2019/08/tensorflow-show-layer.html
【3】https://blog.csdn.net/u014090429/article/details/93487539
【4】https://www.cnblogs.com/azheng333/p/6972619.html
【5】https://machinelearningmastery.com/how-to-visualize-filters-and-feature-maps-in-convolutional-neural-networks/
—————————————————————————————————————————————————
sess.run(['predicted_label/prediction:0', 'softmax/softmax:0'], feed_dict={
    'images_placeholder:0': batch})
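Why are the arguments passed in the form name:0, i.e. why is there a trailing :0? Because these are tensor names rather than op names: a tensor is named <op name>:<output index>, and :0 selects the first output of that op (an op with several outputs, such as one created by tf.split, also has :1, :2, ...). The index has nothing to do with the batch: 'softmax/softmax:0' is a single tensor holding the results for every image in the batch, whatever the batch size.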
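A tensor name is the op name, a colon, and the index of that op's output, and this rule can be checked with a small parser. The sketch relies only on that convention; parse_tensor_name is an illustrative helper, and split:1 is a hypothetical name for the second output of a multi-output op such as one created by tf.split.

```python
def parse_tensor_name(tensor_name):
    """Split a TensorFlow tensor name '<op name>:<output index>' into its two parts."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        raise ValueError("not a tensor name: %r" % tensor_name)
    return op_name, int(index)


for name in ("images_placeholder:0", "softmax/softmax:0", "split:1"):
    print(name, "->", parse_tensor_name(name))
```

A bare op name such as softmax/softmax raises ValueError here, mirroring the fact that get_tensor_by_name needs the :index part.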