【Tensorflow】学习二:图像分类器

本机环境
centos6.8 + python2.7 + tensorflow0.11

下载tensorflow源码

从github上拉去代码并切换到0.11版本:

git clone https://github.com/tensorflow/tensorflow
git checkout r0.11

google-Inception模型示例

执行如下命令,利用google的inception模型识别图片space_shuttle.jpg

cd tensorflow/models/images/imagenet/
python classify_image.py --image_file /home/xiabing/TensorFlow_pics/space_shuttle.jpg

可以看到识别结果如下:


【Tensorflow】学习二:图像分类器_第1张图片
space_shuttle_result.jpg

分析classify_image.py

下面看看classify_image.py的源码
classify_image.py会首先下载分类器模型:
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
下载后会放到本地/tmp/imagenet/路径下:

【Tensorflow】学习二:图像分类器_第2张图片
inception.jpg

训练自己的分类模型

使用tensorflow中examples中的image_retraining来retraining谷歌的inception模型

准备图片数据

准备要训练的每个分类,需要有个对应的文件夹(因为每个子文件夹内的各个图片的label标签就是取分类文件夹名的)类似以下这种:

fruit/banna/
fruit/apple/

每个分类内的数据格式没有规定,本例如下:


【Tensorflow】学习二:图像分类器_第3张图片
每个分类下图片.jpg

使用retraining.py训练

调用如下命令开始训练,参数详解参见retrain.py文件:

python /home/xiabing/TensorFlow/tensorflow/tensorflow/examples/image_retraining/retrain.py --bottleneck_dir /home/xiabing/sd_classify_pics/bottleneck --how_many_training_steps 4000 --model_dir /home/xiabing/sd_classify_pics/model --output_graph /home/xiabing/sd_classify_pics/output_graph.pb --output_labels /home/xiabing/sd_classify_pics/output_labels.txt --image_dir /home/xiabing/TensorFlow_pics/fruit/

首次调用会出现如下错误:

ImportError: cannot import name graph_util

解决办法:

修改retrain.py,把
from tensorflow.python.framework import graph_util
替换为
from tensorflow.python.client import graph_util

再重新执行上面命令,看到如下打印表示训练完成:


训练完成.jpg

训练结果

训练完成后,会在当前目录下生成下面两个文件。查看标签文件,会看到banana和apple。


训练产生结果文件.jpg

labels内容.jpg

使用训练好的模型

在训练结果路径下新建test.py文件,加入如下代码:

  import tensorflow as tf
  import sys

  image_file = sys.argv[1]
  #print(image_file)

  image = tf.gfile.FastGFile(image_file, 'rb').read()

  labels = []
  for label in tf.gfile.GFile("output_labels.txt"):
      labels.append(label.rstrip())

  with tf.gfile.FastGFile("output_graph.pb", 'rb') as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def, name='')

  with tf.Session() as sess:
  softmax_tensor =     sess.graph.get_tensor_by_name('final_result:0')
  predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})

  top = predict[0].argsort()[-len(predict[0]):][::-1]
  for index in top:
        human_string = labels[index]
        score = predict[0][index]
        print(human_string, score)         

测试训练好的模型:

python /home/xiabing/sd_classify_pics/test.py /home/xiabing/TensorFlow_pics/1510114397170.jpg

原始图片:


【Tensorflow】学习二:图像分类器_第4张图片
1510114397170.jpg

测试结果:


结果.jpg

你可能感兴趣的:(【Tensorflow】学习二:图像分类器)