This hands-on example runs Caffe's "Classifying ImageNet" example through the Python interface.
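Create a script named download_model.bat in the caffe-master working directory with the following content (the same command can also be run directly from a shell in that directory):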
./scripts/download_model_binary.py models/bvlc_reference_caffenet
Running it downloads bvlc_reference_caffenet.caffemodel into models/bvlc_reference_caffenet/.
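download_model_binary.py checks the downloaded weights against the `sha1` value in models/bvlc_reference_caffenet/readme.md. If you fetch the .caffemodel some other way, you can repeat that check by hand. Below is a minimal Python 3 sketch of such a check. The helper names sha1_of_file and file_matches_sha1 are hypothetical, and the expected digest is not reproduced here, so copy it from readme.md yourself.

```python
import hashlib


def sha1_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-1 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha1()
    with open(path, "rb") as f:
        # iter() keeps calling f.read until it returns b"" (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def file_matches_sha1(path, expected_sha1):
    """True if `path` exists and its SHA-1 equals `expected_sha1` (case-insensitive)."""
    try:
        return sha1_of_file(path) == expected_sha1.strip().lower()
    except (IOError, OSError):  # missing or unreadable file
        return False
```

For example, call file_matches_sha1(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel', expected) with the digest from readme.md. Reading in chunks keeps memory use flat for a file of a couple of hundred megabytes.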
Write ClassifyingImageNet.py as follows:
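__author__ = 'frank'
import numpy as np
import caffe
import matplotlib.pyplot as plt
caffe_root = '/home/fangjin/caffe/'  # root of the Caffe checkout; adjust to your own path
model_file = caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt'  # network definition
pretrain_file = caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'  # trained weights
mean_file = caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'  # ImageNet mean image, shape (3, 256, 256)
image_file = caffe_root + 'examples/images/cat.jpg'
caffe.set_mode_cpu()  # choose CPU mode before building the network
net = caffe.Classifier(model_file, pretrain_file,
                       mean=np.load(mean_file).mean(1).mean(1),  # average over pixels -> per-channel (BGR) mean
                       channel_swap=(2, 1, 0),  # load_image returns RGB; the model expects BGR
                       raw_scale=255,           # load_image returns [0, 1]; the model expects [0, 255]
                       image_dims=(256, 256))   # resize to 256x256 before cropping to the 227x227 input
input_image = caffe.io.load_image(image_file)  # H x W x 3 float array, RGB, values in [0, 1]
plt.imshow(input_image)
plt.show()
prediction = net.predict([input_image])  # shape (1, 1000): class probabilities for each input image
print('prediction shape: ' + str(prediction[0].shape))
plt.plot(prediction[0])
print('prediction class: ' + str(prediction[0].argmax()))
plt.show()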
Output:
prediction shape: (1000L,)
prediction class: 281
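The `mean`, `channel_swap` and `raw_scale` arguments exist because caffe.io.load_image returns an RGB float image in [0, 1]. The reference model, by contrast, expects BGR pixels in [0, 255] with the ImageNet mean subtracted. Classifier first resizes the image to image_dims (256x256). By default, predict then averages over 10 crops of 227x227 (four corners, the center, and their mirrors).

Below is a minimal numpy sketch of the per-crop steps, in the order caffe.io.Transformer.preprocess applies them. It is a simplification: a single center crop stands in for the resize and the 10-crop oversampling, and the function name preprocess is hypothetical.

```python
import numpy as np


def preprocess(image, mean_image, crop=227, channel_swap=(2, 1, 0), raw_scale=255.0):
    """Approximate Caffe's preprocessing for one H x W x 3 RGB image in [0, 1].

    Simplification (assumption): a center crop replaces Caffe's resize + oversampling.
    Returns a 3 x crop x crop float32 array in the network's BGR, mean-subtracted space.
    """
    # Same reduction as .mean(1).mean(1) in the script: (3, 256, 256) -> (3,)
    mean = mean_image.mean(1).mean(1)
    h, w, _ = image.shape
    top, left = (h - crop) // 2, (w - crop) // 2
    x = image[top:top + crop, left:left + crop, :].astype(np.float32)
    x = x.transpose(2, 0, 1)               # H x W x C -> C x H x W
    x = x[list(channel_swap), :, :]        # RGB -> BGR
    x *= raw_scale                         # [0, 1] -> [0, 255]
    x -= mean[:, np.newaxis, np.newaxis]   # subtract the per-channel mean
    return x
```

The mean is subtracted after the channel swap because the ImageNet mean file is stored in Caffe's BGR order.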
To map this class index to a human-readable label, download the ImageNet synset word list by running caffe/data/ilsvrc12/get_ilsvrc_aux.sh:
#!/usr/bin/env sh
#
# N.B. This does not download the ilsvrcC12 data set, as it is gargantuan.
# This script downloads the imagenet example auxiliary files including:
# - the ilsvrc12 image mean, binaryproto
# - synset ids and words
# - Python pickle-format data of ImageNet graph structure and relative infogain
# - the training splits with labels
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd $DIR
echo "Downloading..."
wget -c http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz
echo "Unzipping..."
tar -xf caffe_ilsvrc12.tar.gz && rm -f caffe_ilsvrc12.tar.gz
echo "Done."
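If the script cannot connect from your network, note that the download link it uses is http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz.
Download the archive manually and extract it into caffe/data/ilsvrc12.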
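If wget or tar is unavailable, the same two steps can be done from Python 3 with the standard library. This is a minimal sketch. The helper names download and extract_and_remove are hypothetical, there is no resume support as there is with `wget -c`, and extraction is only meant for this trusted archive.

```python
import os
import shutil
import tarfile
import urllib.request

URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"


def download(url, dest_path):
    """Stream `url` into `dest_path` (no resume, unlike wget -c)."""
    with urllib.request.urlopen(url) as resp, open(dest_path, "wb") as out:
        shutil.copyfileobj(resp, out)


def extract_and_remove(archive_path, dest_dir):
    """Rough equivalent of `tar -xf archive && rm -f archive`."""
    with tarfile.open(archive_path) as tar:  # compression is auto-detected
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: refuse absolute paths, '..' entries and similar
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
    os.remove(archive_path)
```

From inside caffe/data/ilsvrc12 this would be used as download(URL, "caffe_ilsvrc12.tar.gz") followed by extract_and_remove("caffe_ilsvrc12.tar.gz", ".").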
Then append the following code to the script:
labels_file = caffe_root + 'data/ilsvrc12/synset_words.txt'
labels = np.loadtxt(labels_file, str, delimiter='\t')  # one line per class: line i describes class i
print('output label: ' + str(labels[prediction[0].argmax()]))
Result:
output label: n02123045 tabby, tabby cat
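Each line of synset_words.txt holds a WordNet ID, a space, and a comma-separated list of names. Line i belongs to output class i, which is why index 281 above yields "n02123045 tabby, tabby cat". The lines contain no tab, so `delimiter='\t'` makes np.loadtxt keep each whole line as a single string. For splitting the ID from the names, here is a minimal pure-Python sketch; the function name load_synset_words is hypothetical.

```python
def load_synset_words(lines):
    """Parse synset_words.txt lines into (wnid, description) tuples.

    List index i corresponds to output class i of the 1000-way classifier.
    Blank lines are skipped, as np.loadtxt does.
    """
    labels = []
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue
        wnid, _, words = line.partition(" ")  # split at the first space only
        labels.append((wnid, words))
    return labels
```

In the script this could be called as `load_synset_words(open(labels_file))`. It is then indexed the same way as the numpy array.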
You can also point image_file at your own photo, for example the Lena image.
Result:
output label: n02869837 bonnet, poke bonnet
That is, a wide-brimmed women's hat, which is fairly accurate.
You can also print the top few predictions. Change the code to:
labels_file = caffe_root + 'data/ilsvrc12/synset_words.txt'
labels = np.loadtxt(labels_file, str, delimiter='\t')
label_inds = prediction[0].argsort()[::-1][:5]  # indices of the 5 highest probabilities, best first
print('output label: ' + str(list(zip(prediction[0][label_inds], labels[label_inds]))))
Output for the cat image:
output label: [(0.27352503, 'n02123045 tabby, tabby cat'), (0.23680504, 'n02123159 tiger cat'), (0.15174942, 'n02124075 Egyptian cat'), (0.12903579, 'n02127052 lynx, catamount'), (0.047049168, 'n02119789 kit fox, Vulpes macrotis')]
Output for Lena:
output label: [(0.20391516, 'n02869837 bonnet, poke bonnet'), (0.1764992, 'n04259630 sombrero'), (0.1642044, 'n03124170 cowboy hat, ten-gallon hat'), (0.058862317, 'n04584207 wig'), (0.050728988, "n02669723 academic gown, academic robe, judge's robe")]
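The expression `argsort()[::-1][:5]` sorts all 1000 probabilities in ascending order, reverses them, and keeps the first five. When only k entries are needed out of many, numpy's argpartition can select them without a full sort. Only those k are then sorted. Below is a minimal sketch of both variants; the function names top_k and top_k_partition are hypothetical.

```python
import numpy as np


def top_k(probs, labels, k=5):
    """Return the k (probability, label) pairs with the highest probability, best first."""
    probs = np.asarray(probs)
    idx = np.argsort(probs)[::-1][:k]  # same idea as prediction[0].argsort()[::-1][:5]
    return [(float(probs[i]), labels[i]) for i in idx]


def top_k_partition(probs, k=5):
    """Indices of the k largest values, best first, via argpartition plus a k-element sort."""
    probs = np.asarray(probs)
    part = np.argpartition(probs, -k)[-k:]   # the k largest, in arbitrary order
    return part[np.argsort(probs[part])[::-1]]
```

With prediction[0] and the labels array from above, top_k(prediction[0], labels) gives the same five pairs as the zip expression. For 1000 classes the speed difference is negligible, so the simpler argsort form in the script is fine.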