Tensorflow(12)训练自己的数据

1. 去官网下载

Tensorflow(12)训练自己的数据_第1张图片
image.png

2. 网上下载一些文件,做成像我这样的

Tensorflow(12)训练自己的数据_第2张图片
image.png
  • 爬虫代码https://www.jianshu.com/p/de9936383637

3. 跑测试数据

# 用到了retrain.py 文件
python /Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/retrain.py 
# 只跑最后一层
--bottleneck_dir bottleneck 
# 训练200遍
--how_many_training_steps 200 
# 用到tensorflow(11) inception 模型
--model_dir /Users/chengkai/Documents/06_code/code/tensorflow/inception_model/ 
# 输出文件
--output_graph output_graph.pb 
--output_labels output_labels.txt 
# 用来训练的图片
--image_dir /Users/chengkai/Documents/06_code/code/tensorflow/images/
image.png

4. 跑测试数据

image.png

# coding: utf-8

import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt


lines = tf.gfile.GFile('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/output_labels.txt').readlines()
uid_to_human = {}
#一行一行读取数据
for uid,line in enumerate(lines) :
    #去掉换行符
    line=line.strip('\n')
    uid_to_human[uid] = line

def id_to_string(node_id):
    if node_id not in uid_to_human:
        return ''
    return uid_to_human[node_id]

#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    #遍历目录
    for root,dirs,files in os.walk('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/test/'):
        for file in files:
            if file.startswith("."):
                continue
            print(file)
            #载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
            predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
            predictions = np.squeeze(predictions)#把结果转为1维数据

            #打印图片路径及名称
            image_path = os.path.join(root,file)
            print(image_path)
            #显示图片
            img=Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()

            #排序
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:     
                #获取分类名称
                human_string = id_to_string(node_id)
                #获取该分类的置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()
# 这一步我报错了
TypeError: Cannot interpret feed_dict key as Tensor: The name 'DecodeJpeg/contents:0' refers to a Tensor which does not exist. The operation, 'DecodeJpeg/contents', does not exist in the graph.
Tensorflow(12)训练自己的数据_第3张图片
image.png

你可能感兴趣的:(Tensorflow(12)训练自己的数据)