TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片

下载需要练习的inception模型并看起流程

  1. import tensorflow as tf  
  2. import os  
  3. import tarfile  
  4. import requests  
  5.   
  6. #inception_v3模型下载  
  7. inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'  
  8.   
  9. # 模型存放地址  
  10. inception_pretrain_model_dir = "inception_model"  
  11. if not os.path.exists(inception_pretrain_model_dir):  
  12.     os.makedirs(inception_pretrain_model_dir)  
  13.   
  14. #获取文件名,以及文件路径  
  15. filename = inception_pretrain_model_url.split('/')[-1]  
  16. filepath = os.path.join(inception_pretrain_model_dir, filename)  
  17.   
  18. #下载模型  
  19. if not os.path.exists(filepath):  
  20.     print('download: ', filename)  
  21.     r = requests.get(inception_pretrain_model_url, stream=True)  
  22.     with open(filepath,'wb') as f:  
  23.         for chunk in r.iter_content(chunk_size=1024):  
  24.             if chunk:  
  25.                 f.write(chunk)  
  26. print("finishn: ", filename)  
  27.   
  28. #解压文件  
  29. tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)  
  30.   
  31. #模型结构存放文件  
  32. log_dir = 'inception_log'  
  33. if not os.path.exists(log_dir):  
  34.     os.makedirs(log_dir)  
  35.   
  36. #classify_image_graph_def.pb为google训练好的模型  
  37. inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')  
  38. with tf.Session() as sess:  
  39.     #创建一个图来存放google训练好的模型,load graph 具体实现方法看下面的链接  
  40.     with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:  
  41.         graph_def = tf.GraphDef()  
  42.         graph_def.ParseFromString(f.read())  
  43.         tf.import_graph_def(graph_def, name='')  
  44.   
  45.     #保存图的结构  
  46.     writer = tf.summary.FileWriter(log_dir, sess.graph)  
  47.     writer.close() 
其结构图孺如下
TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片_第1张图片

TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片_第2张图片



使用inception模型检测图片

里面主要写了labels排序等 的实现,以及利用训练好的模型识别图片的实现


  1. import tensorflow as tf  
  2. import os  
  3. import numpy as np  
  4. import re  
  5. from PIL import Image  
  6. import matplotlib.pyplot as plt  
  7.   
  8. class NodeLookup(object):  
  9.     def __init__(self):  
  10.         label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'  
  11.         uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'  
  12.         self.node_lookup = self.load(label_lookup_path, uid_lookup_path)  
  13.   
  14.     def load(self, label_lookup_path, uid_lookup_path):  
  15.         #加载分类字符串n ------ 对应分类名称的文件  
  16.         proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()  
  17.         uid_to_human = {}  
  18.         #一行一行读取数据  
  19.         for line in proto_as_ascii_lines :  
  20.             #去掉换行符  
  21.             line = line.strip('\n')  
  22.             #按照‘\t’分割  
  23.             parsed_items = line.split('\t')  
  24.             #获取分类编号和分类名称  
  25.             uid = parsed_items[0]  
  26.             human_string = parsed_items[1]  
  27.             #保存编号字符串-----与分类名称映射关系  
  28.             uid_to_human[uid] = human_string  
  29.   
  30.   
  31.         #加载分类字符串n ----- 对应分类编号1-1000的文件  
  32.         proto_as_ascii_lines = tf.gfile.GFile(label_lookup_path).readlines()  
  33.         node_id_to_uid = {}  
  34.         for line in proto_as_ascii_lines :  
  35.             if line.startswith('  target_class:'):  
  36.                 #获取分类编号1-1000  
  37.                 target_class = int(line.split(': ')[1])  
  38.             if line.startswith('  target_class_string:'):  
  39.                 #获取编号字符串n****  
  40.                 target_class_string = line.split(': ')[1]  
  41.                 #保存分类编号1-1000与编号字符串n****的映射关系  
  42.                 node_id_to_uid[target_class] = target_class_string[1:-2]  
  43.   
  44.   
  45.         #建立分类编号1-1000对应分类名称的映射关系  
  46.         node_id_to_name = {}  
  47.         for key, val in node_id_to_uid.items():  
  48.             #获取分类名称  
  49.             name = uid_to_human[val]  
  50.             #建立分类编号1-1000到分类名称的映射关系  
  51.             node_id_to_name[key] = name  
  52.         return node_id_to_name  
  53.   
  54.     #传入分类编号1-1000返回分类名称  
  55.     def id_to_string(self, node_id):  
  56.         if node_id not in self.node_lookup:  
  57.             return ''  
  58.         return self.node_lookup[node_id]  
  59.   
  60. #创建一个图来存放google训练好的模型  #2 load graph  
  61. with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:  
  62.     graph_def = tf.GraphDef()  
  63.     graph_def.ParseFromString(f.read())  
  64.     tf.import_graph_def(graph_def, name='')  
  65.   
  66.   
  67. with tf.Session() as sess:  
  68.     softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')  
  69.     #遍历目录  
  70.     for root, dirs, files in os.walk('images/'):  
  71.         for file in files:  
  72.             #载入图片  
  73.             image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()  
  74.             predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式  
  75.             predictions = np.squeeze(predictions)#把结果转为1维  
  76.   
  77.             #打印图片路径及名称  
  78.             image_path = os.path.join(root,file)  
  79.             print(image_path)  
  80.             #显示图片  
  81.             img = Image.open(image_path)  
  82.             plt.imshow(img)  
  83.             plt.axis('off')  
  84.             plt.show()  
  85.   
  86.             #排序  
  87.             top_k = predictions.argsort()[-5:][::-1]  
  88.             node_lookup = NodeLookup()  
  89.             for node_id in top_k:  
  90.                 #获取分类名称  
  91.                 human_string = node_lookup.id_to_string(node_id)  
  92.                 #获取该分类的置信度  
  93.                 score = predictions[node_id]  
  94.                 print('%s (score = %.5f)' % (human_string, score))  
  95.             print()  

TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片_第3张图片TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片_第4张图片


TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片_第5张图片


TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片_第6张图片

识别的准确率还是可以的,如果遇到不认识的狗子,识别一下就好啦,哈哈


你可能感兴趣的:(TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片)