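At the time of writing, NASNet is among the best-performing image classification networks, and both tensorflow/models and yeephycho/nasnet-tensorflow have published training and testing code for it. However, the example they provide is just a single command line; it is so heavily wrapped that it is hard to understand and hard to adapt for your own tests. To test the model myself with as little Python code as possible, I read the source code and searched the related issues on GitHub, and eventually worked it out.
1. First, import the required packages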
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress the warnings printed when importing tensorflow
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# the line above is an IPython/Jupyter magic; remove it when running as a plain script
import tensorflow as tf
import scipy.io
import skimage.transform
2. Then clone the models/research/slim/ subdirectory (the method was briefly described in an earlier post) and add its path to Python's module search path
import sys
nets_path = r'...\slim'
sys.path.insert(0,nets_path)
from nets.nasnet import nasnet
slim = tf.contrib.slim
3. Generate the ImageNet labels with the code already provided in slim/datasets
from datasets import imagenet
labels = imagenet.create_readable_names_for_imagenet_labels()
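Note that there are 1001 labels here, whereas one would normally expect 1000. This is because the designers of Inception and many other models deliberately added a '0: background' class, so the index of every original class is shifted up by one. This can be verified by inspecting the contents of labels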
>> print(len(labels))
1001
>> print(labels)
{0: 'background', 1: 'tench, Tinca tinca', 2: 'goldfish, Carassius auratus', 3: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 4: 'tiger shark, Galeocerdo cuvieri', 5: 'hammerhead, hammerhead shark', 6: 'electric ray, crampfish, numbfish, torpedo',...
whereas the original 1000 ImageNet classes begin, in order, with
tench, Tinca tinca
goldfish, Carassius auratus
great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
tiger shark, Galeocerdo cuvieri
hammerhead, hammerhead shark
electric ray, crampfish, numbfish, torpedo
So the 1001-class version indeed differs only by the extra background class at index 0
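The offset can be illustrated without slim at all. The sketch below builds the shifted mapping from a plain list of class names by reserving index 0 for background; `add_background` and `to_1000_index` are hypothetical helpers written for illustration, and only the first three class names from the list above are used.

```python
def add_background(names):
    """Return {index: name} with 'background' at 0 and every class shifted up by one."""
    labels = {0: 'background'}
    for i, name in enumerate(names):
        labels[i + 1] = name
    return labels

def to_1000_index(idx_1001):
    """Map a 1001-class index back to the 1000-class index (background has none)."""
    if idx_1001 == 0:
        raise ValueError('index 0 is the background class')
    return idx_1001 - 1

first_names = ['tench, Tinca tinca',
               'goldfish, Carassius auratus',
               'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias']
demo = add_background(first_names)
print(len(demo))         # 4
print(demo[1])           # tench, Tinca tinca
print(to_1000_index(2))  # 1
```

This matters when comparing predictions against ground truth stored with 0-based, 1000-class indices: one side has to be shifted by one.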
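4. Load the pretrained NASNet model. This requires downloading the NASNet checkpoint files beforehand. The model comes in two variants, mobile and large. The code below loads mobile (to load the large model, replace every mobile with large and use the 331x331 input size noted in the code comment)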
ckpt_path = r'...\nasnet-a_mobile_04_10_2017\model.ckpt'
tf.reset_default_graph()
x = tf.placeholder(tf.float32,shape = [None,224,224,3],name = 'im')
# the input shape is (224,224,3) for mobile and (331,331,3) for large; see slim/nets/nasnet/nasnet.py
mean = tf.constant([[[[ 123.68/255, 116.779/255, 103.939/255]]]],name = 'im_mean',dtype = tf.float32) # pixel values are in [0,1]
x1 = tf.subtract(x,mean)
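slim.get_or_create_global_step() # omitting this line raises an error
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
    # is_training = False builds the network in inference mode (the default is True)
    net,endpoints = nasnet.build_nasnet_mobile(images = x1,num_classes = 1000 + 1,is_training = False)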
saver = tf.train.Saver()
sess = tf.InteractiveSession()
saver.restore(sess,ckpt_path)
prob = tf.nn.softmax(net,axis = 1)
y = tf.argmax(prob,axis = 1)
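One caveat on the preprocessing: as far as I can tell, slim's own evaluation pipeline maps the NASNet models to Inception-style preprocessing (see slim/preprocessing/preprocessing_factory.py), which rescales [0,1] pixels to [-1,1] rather than subtracting per-channel means. This is worth checking against your own copy of the source. The sketch below shows that rescaling in numpy; `inception_style_scale` is a hypothetical name used only for illustration.

```python
import numpy as np

def inception_style_scale(im):
    """Map pixel values from [0,1] to [-1,1]: subtract 0.5, then multiply by 2."""
    im = np.asarray(im, dtype=np.float32)
    return (im - 0.5) * 2.0

demo = np.array([0.0, 0.5, 1.0], dtype=np.float32)
print(inception_style_scale(demo).tolist())  # [-1.0, 0.0, 1.0]
```

If that assumption holds, the equivalent change in the graph above would be to compute `x1` as `tf.multiply(tf.subtract(x, 0.5), 2.0)` instead of subtracting `mean`.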
5. Test
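The test call assumes a variable `im` already exists: a float32 array of shape (224,224,3) with pixel values in [0,1], matching the placeholder defined in step 4. The following is one rough way to prepare it, using numpy only; `prepare_image` is a hypothetical helper, the random array merely stands in for a real photo, and the nearest-neighbour resize is a simplification (`skimage.transform.resize`, imported earlier, would give smoother interpolation).

```python
import numpy as np

def prepare_image(img_uint8, size=224):
    """Center-crop an HxWx3 uint8 image to a square, resize to size x size, scale to [0,1]."""
    img = np.asarray(img_uint8)
    h, w = img.shape[:2]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    crop = img[top:top + side, left:left + side, :3]
    # Nearest-neighbour resize via integer index arrays (numpy only).
    idx = (np.arange(size) * side / float(size)).astype(int)
    resized = crop[idx][:, idx]
    return resized.astype(np.float32) / 255.0

# Stand-in for a real photo loaded as a uint8 RGB array.
fake = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
im = prepare_image(fake)
print(im.shape, im.dtype)  # (224, 224, 3) float32
```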
y0 = sess.run(y,{x:np.expand_dims(im,0)})[0] # im: float32 image of shape (224,224,3) with pixel values in [0,1]
print(y0,labels[y0])
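A single argmax hides how confident the model is, so it can help to look at the top few classes as well. The sketch below ranks a probability vector with numpy; `top_k` is a hypothetical helper, and the three-entry `demo_labels` and `demo_probs` are made-up stand-ins for the real `labels` dictionary and the softmax output.

```python
import numpy as np

def top_k(probs, labels, k=5):
    """Return the k most probable (index, name, probability) triples, best first."""
    probs = np.asarray(probs)
    order = np.argsort(probs)[::-1][:k]
    return [(int(i), labels[int(i)], float(probs[i])) for i in order]

demo_labels = {0: 'background', 1: 'tench, Tinca tinca', 2: 'goldfish, Carassius auratus'}
demo_probs = [0.1, 0.2, 0.7]
for idx, name, p in top_k(demo_probs, demo_labels, k=2):
    print(idx, name, p)
```

With the session from step 4, the real probability vector would be obtained as `sess.run(prob,{x:np.expand_dims(im,0)})[0]` and passed in together with `labels`.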
In fact, slim/eval_image_classifier.py contains all the information needed; reading that code patiently is enough to work this out.