Importing NASNet and Testing It

Introduction

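At the time of writing, NASNet is among the best-performing image classification networks, and both tensorflow/models and yeephycho/nasnet-tensorflow have released training and evaluation code for it. The example they provide, however, is a single command line; it is so heavily wrapped that it is hard to understand or to adapt for your own tests. To run a test myself with as little Python code as possible, I read through the source code and searched the related issues on GitHub, and eventually got it working.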

Implementation

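1. First, import the required libraries

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress the warnings printed when importing tensorflow

import numpy as np
import matplotlib.pyplot as plt
# Jupyter/IPython magic; remove the next line when running as a plain script
%matplotlib inline
import tensorflow as tf
import scipy.io
import skimage.transform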

2. Next, clone the models/research/slim/ subdirectory (an earlier post briefly covers how), then add its path to the Python module search path

import sys
nets_path = r'...\slim'
sys.path.insert(0,nets_path)

from nets.nasnet import nasnet

slim = tf.contrib.slim

3. Generate the ImageNet labels with the code already provided in slim/datasets

from datasets import imagenet
labels = imagenet.create_readable_names_for_imagenet_labels()

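Note that the labels cover 1001 classes, whereas ImageNet normally has 1000. This is because the designers of Inception and many other models deliberately added '0: background', so every original class index is shifted up by one. This can be verified by inspecting the contents of labels: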

 >> print(len(labels))
1001
 >> print(labels)
{0: 'background', 1: 'tench, Tinca tinca', 2: 'goldfish, Carassius auratus', 3: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 4: 'tiger shark, Galeocerdo cuvieri', 5: 'hammerhead, hammerhead shark', 6: 'electric ray, crampfish, numbfish, torpedo',...

whereas the 1000 ImageNet classes begin with:

tench, Tinca tinca
goldfish, Carassius auratus
great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
tiger shark, Galeocerdo cuvieri
hammerhead, hammerhead shark
electric ray, crampfish, numbfish, torpedo

So the 1001-class version indeed only adds the background class at index 0.
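
The index shift described above can be sketched as a pair of small conversion helpers. The function names and the tiny stand-in label tables are hypothetical and exist only for illustration; they are not part of slim.

```python
# Hypothetical helpers: convert between the 1001-class slim index
# (0 = background) and the plain 0-based 1000-class ImageNet index.

def slim_to_imagenet_index(slim_index):
    """Map a 1001-class index (1..1000) to the 0-based 1000-class index."""
    if slim_index == 0:
        raise ValueError("index 0 is the background class; it has no ImageNet counterpart")
    if not 1 <= slim_index <= 1000:
        raise ValueError("slim index must be in [0, 1000]")
    return slim_index - 1


def imagenet_to_slim_index(imagenet_index):
    """Map a 0-based 1000-class index (0..999) to the 1001-class index."""
    if not 0 <= imagenet_index <= 999:
        raise ValueError("ImageNet index must be in [0, 999]")
    return imagenet_index + 1


# Stand-ins for the first entries of the two label tables shown above.
labels_1001 = {0: 'background', 1: 'tench, Tinca tinca', 2: 'goldfish, Carassius auratus'}
labels_1000 = ['tench, Tinca tinca', 'goldfish, Carassius auratus']

# Every non-background entry lines up once the offset is removed.
for slim_index in (1, 2):
    assert labels_1001[slim_index] == labels_1000[slim_to_imagenet_index(slim_index)]
```

This matters mostly when comparing predictions against ground-truth files that use the 1000-class numbering.
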
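4. Load the pretrained NASNet model. This requires downloading the NASNet checkpoint file beforehand. The model comes in two variants, mobile and large. The code below loads mobile (to load large, replace every mobile with large and change the input size to 331, as noted in the code comment)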

ckpt_path = r'...\nasnet-a_mobile_04_10_2017\model.ckpt'

tf.reset_default_graph()
x = tf.placeholder(tf.float32,shape = [None,224,224,3],name = 'im')
# the mobile input shape is (224,224,3) and the large one is (331,331,3); see slim/nets/nasnet/nasnet.py
# per-channel mean, scaled because the pixel values are in [0,1]
mean = tf.constant([[[[ 123.68/255,  116.779/255,  103.939/255]]]],name = 'im_mean',dtype = tf.float32)
x1 = tf.subtract(x,mean)

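slim.get_or_create_global_step() # omitting this line raises an error
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
    # 1000 ImageNet classes plus the background class;
    # is_training = False builds the inference graph (no dropout, batch norm uses its moving averages)
    net,endpoints = nasnet.build_nasnet_mobile(images = x1,num_classes = 1000 + 1,is_training = False)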

saver = tf.train.Saver()
sess = tf.InteractiveSession()

saver.restore(sess,ckpt_path)

prob = tf.nn.softmax(net,axis = 1)
y = tf.argmax(prob,axis = 1)
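
The tensor y keeps only the single most likely class. When checking a model by hand, the top few classes are often more informative. Below is a minimal NumPy sketch of reading the top-k entries from an array of logits such as the one `sess.run(net, ...)` would return. The helper names `softmax` and `top_k` and the stand-in logits are assumptions made for illustration.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=1, keepdims=True)


def top_k(probabilities, k=5):
    """Return (indices, probabilities) of the k most likely classes per row, best first."""
    order = np.argsort(probabilities, axis=1)[:, ::-1][:, :k]
    return order, np.take_along_axis(probabilities, order, axis=1)


# Stand-in logits for one image and six classes; a real run would have 1001 columns.
logits = np.array([[0.1, 2.0, 0.3, 4.0, 1.0, 0.5]], dtype=np.float32)
probs = softmax(logits)
indices, values = top_k(probs, k=3)
print(indices[0], values[0])
```

With the real model, each returned index can be looked up in labels exactly as in the single-class case.
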

5. Test

# im is assumed to be a float array of shape (224,224,3) with pixel values in [0,1]
y0 = sess.run(y,{x:np.expand_dims(im,0)})[0]
print(y0,labels[y0])
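
The test above uses an array im that this post never constructs. A minimal sketch of preparing it follows, assuming the image has already been resized to 224x224 (for example with skimage.transform.resize, which is imported in step 1). The helper name `to_unit_float` and the random stand-in image are hypothetical.

```python
import numpy as np


def to_unit_float(image, size=224):
    """Return an HxWx3 float32 array with pixel values in [0, 1].

    Assumes the image was already resized to size x size.
    """
    image = np.asarray(image)
    if image.shape != (size, size, 3):
        raise ValueError("expected shape (%d, %d, 3), got %r" % (size, size, image.shape))
    if image.dtype == np.uint8:
        image = image.astype(np.float32) / 255.0  # uint8 0..255 -> float 0..1
    else:
        image = image.astype(np.float32)
    if image.min() < 0.0 or image.max() > 1.0:
        raise ValueError("pixel values must lie in [0, 1]")
    return image


# Random stand-in for a decoded RGB image.
raw = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
im = to_unit_float(raw)
print(im.shape, im.dtype)
```

The explicit shape and range checks are there because a wrongly scaled input usually does not raise an error in the graph; it only produces poor predictions.
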

Summary

In fact, slim/eval_image_classifier.py contains all the information needed; reading the code patiently is enough to work this out.
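
One detail worth checking when reading that script is the input normalisation. To my understanding, the evaluation script picks its preprocessing through slim's preprocessing factory, which for the NASNet models appears to use the Inception-style scaling to [-1, 1] rather than the per-channel mean subtraction used in step 4. This is an assumption that should be verified against your own copy of the repository. The NumPy sketch below only contrasts the two normalisations; both function names are hypothetical.

```python
import numpy as np


def mean_subtract(image01):
    """Normalisation used in step 4: subtract per-channel means scaled to [0, 1]."""
    mean = np.array([123.68, 116.779, 103.939], dtype=np.float32) / 255.0
    return image01 - mean


def inception_scale(image01):
    """Inception-style normalisation (assumed): map [0, 1] linearly to [-1, 1]."""
    return (image01 - 0.5) * 2.0


# A uniform mid-grey stand-in image with pixel values in [0, 1].
image01 = np.full((224, 224, 3), 0.5, dtype=np.float32)
print(mean_subtract(image01)[0, 0], inception_scale(image01)[0, 0])
```

The two variants produce inputs with different ranges, so if predictions look off it may be worth trying the second one in place of the tf.subtract step.
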
