读取pb模型进行预测

本程序解析一个tfrecord文件中的数据,然后调用训练好的pb模型文件预测这些数据的类别,并返回一个列表。

 之前的训练程序和所需的数据都可以在这里找到:https://blog.csdn.net/macunshi/article/details/86220389

如果只想单独运行这一个程序,那么在此提供一个本地训练的模型和一个数据文件。

 

 

#encoding: utf-8
# prediction.py
# Tensorflow 1.10.0

import os
import numpy as np
from PIL import Image
import tensorflow as tf
  
def parse(test_data_filename, num_examples=400):
    """Decode records from a TFRecord file and save each one as a JPEG.

    Every record is expected to carry three features: ``data`` (65536
    float32 values, reshaped here to a 256x256 image), ``label`` (one int64)
    and ``id`` (one int64). Each image is written to ``test_data/<id>.jpg``
    relative to the current working directory.

    Args:
        test_data_filename: Path of the TFRecord file to read.
        num_examples: How many records to pull from the input queue.
            Defaults to 400, the count the original script hard-coded.
            NOTE(review): ``string_input_producer`` is called without
            ``num_epochs``, so the file is presumably cycled when it holds
            fewer records than requested -- confirm against the dataset.

    Returns:
        List of the JPEG file names (without directory) in read order.
    """
    print ("数据解析中...")
    out_dir = "test_data"
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # Build the input pipeline in a private graph so repeated calls do not
    # pile up reader/queue ops (and their queue runners) in the global
    # default graph, and so the ops cannot clash with a model imported later.
    with tf.Graph().as_default():
        reader = tf.TFRecordReader()
        filename_queue = tf.train.string_input_producer([test_data_filename])

        _, serialized_example = reader.read(filename_queue)

        features = tf.parse_single_example(serialized_example, features={
            'data': tf.FixedLenFeature([65536], tf.float32),
            'label': tf.FixedLenFeature([1], tf.int64),
            'id': tf.FixedLenFeature([1], tf.int64)})

        image_tensor = features['data']
        id_tensor = features['id']

        filenames = []
        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for _ in range(num_examples):
                    im, record_id = sess.run([image_tensor, id_tensor])
                    im = im.reshape(256, 256)
                    # Map pixel values from [-1, 1] to [0, 255]; clip so that
                    # slightly out-of-range values do not wrap in the cast.
                    im = np.clip((im + 1) * 255 / 2, 0, 255)
                    new_im = Image.fromarray(np.uint8(im))
                    # 'id' is a length-1 vector; take the scalar directly
                    # instead of stripping brackets from its string form.
                    name = str(int(record_id[0])) + ".jpg"
                    new_im.save("test_data/" + name)
                    filenames.append(name)
            finally:
                # Stop and join the queue-runner threads before the session
                # closes; otherwise they are left running in the background.
                coord.request_stop()
                coord.join(threads)
    return filenames

def model_test(test_data_filename):
    """Classify every image extracted from a TFRecord file with the frozen model.

    The records are first dumped to JPEG files by ``parse``; each file is
    then fed to the graph stored in ``model/my_train.pb`` through the
    ``DecodeJpeg/contents:0`` input, and the ``evaluation/out_prob:0``
    output is collected.

    Args:
        test_data_filename: Path of the TFRecord file to classify.

    Returns:
        A flat list with the model output of every image, each value
        incremented by one.
    """
    filenames=parse(test_data_filename)

    # Import the frozen model into its own graph so that it cannot collide
    # with ops already present in the process-wide default graph.
    graph = tf.Graph()
    with graph.as_default():
        with tf.gfile.GFile('model/my_train.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    predictions=[]
    with tf.Session(graph=graph) as sess:
        softmax_tensor = graph.get_tensor_by_name('evaluation/out_prob:0')
        for count, filename in enumerate(filenames, 1):
            # Read the encoded JPEG bytes inside a context manager so the
            # file handle is closed after every image.
            with tf.gfile.GFile(os.path.join("test_data/", filename), 'rb') as image_file:
                image_data = image_file.read()
            prediction = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
            # NOTE(review): the +1 presumably shifts 0-based class indices
            # to 1-based labels -- confirm against the training code.
            predictions.extend(prediction+1)
            print ("第"+str(count)+"张分类完毕")
    return predictions
            
def main(test_data_filename="TFcodeX_1.tfrecord"):
    """Run the prediction pipeline on one TFRecord file and print the result.

    Args:
        test_data_filename: TFRecord file to classify. Defaults to the
            sample file ``TFcodeX_1.tfrecord``; replace it with
            ``TFcodeX_test.tfrecord`` to run on the test set.
    """
    label=model_test(test_data_filename)
    print ("\n预测结果向量:\n",label)


# Only run the pipeline when executed as a script, not when imported.
if __name__ == "__main__":
    main()

 

你可能感兴趣的:(深度学习)