TensorFlow models 的slim 模块 使用预训练模型进行识别

原文链接: TensorFlow models 的slim 模块 使用预训练模型进行识别

上一篇: scrapy 斗鱼 主播信息爬取

下一篇: TensorFlow vgg 预训练模型使用

下载

https://github.com/tensorflow/models

解压后

TensorFlow models 的slim 模块 使用预训练模型进行识别_第1张图片

在该文件夹内执行,下载flowers数据集并转化为tfrecord格式保存到指定目录中

python download_and_convert_data.py --dataset_name=flowers --dataset_dir=/tmp/data/flowers

TensorFlow models 的slim 模块 使用预训练模型进行识别_第2张图片

将slim ,datasets  ,nets,  preprocessing 复制到D:\ProgramData\Anaconda3\Lib\site-packages下,方便在项目中引入

可以看到,有两个数据集,一个train和validation,在读取时,需要指定一个数据集,然后创建provider对象,接着就可以从provider里读取数据了,读取默认是随机的,可以加入

shuffle 是否打乱顺序,

num_epochs 循环次数,默认为None表示无限循环,

num_readers并行读取器数目,

common_queue_capacity 队列大小

等参数

import tensorflow as tf
from datasets import flowers
import pylab

slim = tf.contrib.slim

DATA_DIR = "D:/tmp/data/flowers"

# 选择数据集validation
dataset = flowers.get_split('validation', DATA_DIR)

# 创建一个provider
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
# 通过provider的get拿到内容
[image, label] = provider.get(['image', 'label'])
print(image.shape)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# 启动队列
tf.train.start_queue_runners()
# 获取数据
image_batch, label_batch = sess.run([image, label])
# 显示
print(label_batch)
pylab.imshow(image_batch)
pylab.show()

TensorFlow models 的slim 模块 使用预训练模型进行识别_第3张图片

使用预训练好的模型进行图像识别  下载 Inception-ResNet-v2

下载地址

https://github.com/tensorflow/models/tree/master/research/slim/#Pretrained

TensorFlow models 的slim 模块 使用预训练模型进行识别_第4张图片

下载后解压保存

使用该网络进行识别

import tensorflow as tf

from PIL import Image
from matplotlib import pyplot as plt
from nets import inception
import numpy as np
from datasets import imagenet

tf.reset_default_graph()
image_size = inception.inception_resnet_v2.default_image_size
names = imagenet.create_readable_names_for_imagenet_labels()

slim = tf.contrib.slim

checkpoint_file = 'point/inception_resnet_v2_2016_08_30.ckpt'
sample_images = ['test.jpg', 'fish.jpg', 'img.jpg']

input_imgs = tf.placeholder("float", [None, image_size, image_size, 3])

# Load the model
sess = tf.Session()
arg_scope = inception.inception_resnet_v2_arg_scope()

with slim.arg_scope(arg_scope):
    logits, end_points = inception.inception_resnet_v2(input_imgs, is_training=False)

saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)

for image in sample_images:
    reimg = Image.open(image).resize((image_size, image_size))
    reimg = np.array(reimg)
    reimg = reimg.reshape(-1, image_size, image_size, 3)

    plt.figure()
    p1 = plt.subplot(121)
    p2 = plt.subplot(122)

    p1.imshow(reimg[0])  # 显示图片
    p1.axis('off')
    p1.set_title("organization image")

    reimg_norm = 2 * (reimg / 255.0) - 1.0

    p2.imshow(reimg_norm[0])  # 显示图片
    p2.axis('off')
    p2.set_title("input image")

    plt.show()

    predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_imgs: reimg_norm})

    print(np.max(predict_values), np.max(logit_values))
    print(np.argmax(predict_values), np.argmax(logit_values), names[np.argmax(logit_values)])

识别结果

TensorFlow models 的slim 模块 使用预训练模型进行识别_第5张图片

0.27505645 8.079795
809 809 sombrero  # 宽边帽

TensorFlow models 的slim 模块 使用预训练模型进行识别_第6张图片

0.25394952 6.9343405
768 768 rubber eraser, rubber, pencil eraser

你可能感兴趣的:(图像识别,tensorflow,python,深度学习,机器学习)