TensorFlow直接使用ckpt模型predict不用restore

# -*- coding: utf-8 -*-

# from util import *
import cv2
import numpy as np
import tensorflow as tf
# from tensorflow.python.framework import graph_util
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
image_path = './8760.pgm'

input_checkpoint = './model/xu_spatial_model_1340.ckpt'

sess = tf.Session()
saver = tf.train.import_meta_graph(input_checkpoint + '.meta')
saver.restore(sess, input_checkpoint)

# input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
input_image_tensor = sess.graph.get_tensor_by_name("coef_input:0")
is_training = sess.graph.get_tensor_by_name('is_training:0')
batch_size = sess.graph.get_tensor_by_name('batch_size:0')
# 定义输出的张量名称
output_tensor_name = sess.graph.get_tensor_by_name("xuNet/logits:0")  # xuNet/Logits/logits
image = cv2.imread(image_path, 0)
# 读取测试图片
out = sess.run(output_tensor_name, feed_dict={input_image_tensor: np.reshape(image, (1, 512, 512, 1)),
                                              is_training: False,
                                              batch_size: 1})
print(out)

ckpt模型中的所有节点名称,可以这样查看

[n.name for n in tf.get_default_graph().as_graph_def().node]

你可能感兴趣的:(计算机视觉,解决方案,tensorflow)