原文链接: TensorFlow models 的slim 模块 使用预训练模型进行识别
上一篇: scrapy 斗鱼 主播信息爬取
下一篇: TensorFlow vgg 预训练模型使用
下载
https://github.com/tensorflow/models
解压后
在该文件夹内执行,下载flowers数据集并转化为tfrecord格式保存到指定目录中
python download_and_convert_data.py --dataset_name=flowers --dataset_dir=/tmp/data/flowers
将slim ,datasets ,nets, preprocessing 复制到D:\ProgramData\Anaconda3\Lib\site-packages下,方便在项目中引入
可以看到,有两个数据集,一个train和validation,在读取时,需要指定一个数据集,然后创建provider对象,接着就可以从provider里读取数据了,读取默认是随机的,可以加入
shuffle 是否打乱顺序,
num_epochs 循环次数,默认为None表示无限循环,
num_readers并行读取器数目,
common_queue_capacity 队列大小
等参数
import tensorflow as tf
from datasets import flowers
import pylab
slim = tf.contrib.slim
DATA_DIR = "D:/tmp/data/flowers"
# 选择数据集validation
dataset = flowers.get_split('validation', DATA_DIR)
# 创建一个provider
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
# 通过provider的get拿到内容
[image, label] = provider.get(['image', 'label'])
print(image.shape)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# 启动队列
tf.train.start_queue_runners()
# 获取数据
image_batch, label_batch = sess.run([image, label])
# 显示
print(label_batch)
pylab.imshow(image_batch)
pylab.show()
使用预训练好的模型进行图像识别 下载 Inception-ResNet-v2
下载地址
https://github.com/tensorflow/models/tree/master/research/slim/#Pretrained
下载后解压保存
使用该网络进行识别
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt
from nets import inception
import numpy as np
from datasets import imagenet
tf.reset_default_graph()
image_size = inception.inception_resnet_v2.default_image_size
names = imagenet.create_readable_names_for_imagenet_labels()
slim = tf.contrib.slim
checkpoint_file = 'point/inception_resnet_v2_2016_08_30.ckpt'
sample_images = ['test.jpg', 'fish.jpg', 'img.jpg']
input_imgs = tf.placeholder("float", [None, image_size, image_size, 3])
# Load the model
sess = tf.Session()
arg_scope = inception.inception_resnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = inception.inception_resnet_v2(input_imgs, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
for image in sample_images:
reimg = Image.open(image).resize((image_size, image_size))
reimg = np.array(reimg)
reimg = reimg.reshape(-1, image_size, image_size, 3)
plt.figure()
p1 = plt.subplot(121)
p2 = plt.subplot(122)
p1.imshow(reimg[0]) # 显示图片
p1.axis('off')
p1.set_title("organization image")
reimg_norm = 2 * (reimg / 255.0) - 1.0
p2.imshow(reimg_norm[0]) # 显示图片
p2.axis('off')
p2.set_title("input image")
plt.show()
predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_imgs: reimg_norm})
print(np.max(predict_values), np.max(logit_values))
print(np.argmax(predict_values), np.argmax(logit_values), names[np.argmax(logit_values)])
识别结果
0.27505645 8.079795
809 809 sombrero # 宽边帽
0.25394952 6.9343405
768 768 rubber eraser, rubber, pencil eraser