tensorflow(九):自定义模型

一、数据准备

tensorflow(九):自定义模型_第1张图片

tensorflow(九):自定义模型_第2张图片

  • animal:http://www.robots.ox.ac.uk/~vgg/data/pets/ (images.tar.gz,~765M)
  • flower:http://www.robots.ox.ac.uk/~vgg/data/flowers/ (17flowers.tgz,~58.8M)
  • plane:http://www.robots.ox.ac.uk/~vgg/data/airplanes_side/airplanes_side.tar (airplanes_side.tar,~43.7M)
  • house:http://www.robots.ox.ac.uk/~vgg/data/houses/houses.tar (houses.tar,~16.9M)
  • guitar:http://www.robots.ox.ac.uk/~vgg/data/guitars/guitars.tar (guitars.tar,~24.5M)

二、预训练

首先下载 tensorflow 的源码,GitHub 地址:https://github.com/tensorflow/tensorflow,解压并放在指定位置,比如 D:\TensorFlow目录下。然后写个批处理文件去执行 TensorFlow 中retrain.py程序,自动训练模型。

python F:\code\tensorflow-r1.8\tensorflow\examples\image_retraining\retrain.py ^
--bottleneck_dir bottleneck ^
--how_many_training_steps 200 ^
--model_dir inception_model/ ^
--output_graph output_graph.pb ^
--output_labels output_labels.txt ^
--image_dir data/
pause

三、测试模型

# coding: utf-8

import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt

lines = tf.gfile.GFile('F:\\code\\retrain\\output_labels.txt').readlines()
uid_to_human = {}
# 一行一行读取数据
for uid, line in enumerate(lines):
    # 去掉换行符
    line=line.strip('\n')
    uid_to_human[uid] = line

def id_to_string(node_id):
    if node_id not in uid_to_human:
        return ''
    return uid_to_human[node_id]
# 创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('F:\\code\\retrain\\output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    #遍历目录
    for root,dirs,files in os.walk('F:\\code\\retrain\\images\\'):	#指定测试图片的位置
        for file in files:
            #载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
            predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
            predictions = np.squeeze(predictions)#把结果转为1维数据

            #打印图片路径及名称
            image_path = os.path.join(root,file)
            print(image_path)

            #排序
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:
                #获取分类名称
                human_string = id_to_string(node_id)
                #获取该分类的置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()
            #显示图片
            img=Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()

 

你可能感兴趣的:(python深度学习,tensorflow)