[Caffe Source Code Study] Chapter 2: Hands-on (2): ImageNet Classification

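This hands-on example runs Caffe's "Classifying ImageNet" example through the Python interface (pycaffe), using the pretrained CaffeNet model.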

1. Download the model

From the caffe-master root directory, run the following command (the author saved it in a script named download_model.bat):

./scripts/download_model_binary.py models/bvlc_reference_caffenet 

This downloads bvlc_reference_caffenet.caffemodel into models/bvlc_reference_caffenet/.
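As far as I know, download_model_binary.py also checks the downloaded file against the `sha1` value listed in the model's readme.md. If you fetch the .caffemodel by hand instead, the following is a minimal sketch of an equivalent integrity check using Python's hashlib. The helper names `sha1_of_file` and `matches_sha1` are hypothetical, and the expected digest is assumed to be copied from models/bvlc_reference_caffenet/readme.md.

```python
import hashlib


def sha1_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-1 digest of a file, reading it in chunks.

    Chunked reading keeps memory use flat even for a ~200 MB .caffemodel.
    """
    digest = hashlib.sha1()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read() until it returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_sha1(path, expected_sha1):
    """True if the file's SHA-1 equals the expected hex digest (case-insensitive)."""
    return sha1_of_file(path) == expected_sha1.strip().lower()
```

A mismatch usually means a truncated or corrupted download, in which case deleting the file and downloading again is the simplest fix.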

2. Classification

ClassifyingImageNet.py is as follows:

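__author__ = 'frank'

import numpy as np
import caffe
import matplotlib.pyplot as plt

# Adjust caffe_root to point at your own Caffe checkout
caffe_root = '/home/fangjin/caffe/'

# Network definition (deploy prototxt) and pretrained weights
model_file = caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt'
pretrain_file = caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'
# ImageNet mean image (3 x 256 x 256), shipped with pycaffe
mean_file = caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'
image_file = caffe_root + 'examples/images/cat.jpg'

# Select CPU mode before constructing the network
caffe.set_mode_cpu()

net = caffe.Classifier(model_file, pretrain_file,
                       mean=np.load(mean_file).mean(1).mean(1),  # per-channel (BGR) mean
                       channel_swap=(2, 1, 0),  # RGB -> BGR, the channel order the model was trained with
                       raw_scale=255,           # load_image returns values in [0, 1]; the model expects [0, 255]
                       image_dims=(256, 256))   # resize to 256x256 before cropping to the 227x227 input

# load_image returns an H x W x 3 RGB float array in [0, 1]
input_image = caffe.io.load_image(image_file)
plt.imshow(input_image)
plt.show()

# predict() averages over 10 crops by default (oversample=True); the result has shape (N, 1000)
prediction = net.predict([input_image])
print('prediction shape: %s' % (prediction[0].shape,))
print('prediction class: %d' % prediction[0].argmax())

# Plot the probability assigned to each of the 1000 classes
plt.plot(prediction[0])
plt.show()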
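caffe.Classifier passes these arguments to a caffe.io.Transformer. As I understand pycaffe's Transformer.preprocess, each image is transposed from H x W x C to C x H x W, its channels are reordered by channel_swap, it is multiplied by raw_scale, and then the mean is subtracted. The pure-Python sketch below mimics only those steps on nested lists, assuming an RGB image in [0, 1] like the one caffe.io.load_image returns; it skips resizing and cropping, and the function names `channel_mean` and `preprocess` are hypothetical.

```python
def channel_mean(mean_image):
    """Average a C x H x W mean image down to one value per channel.

    Same result as np.load(mean_file).mean(1).mean(1) in the script above.
    """
    return [sum(sum(row) for row in plane) / (len(plane) * len(plane[0]))
            for plane in mean_image]


def preprocess(image_hwc, mean_bgr, channel_swap=(2, 1, 0), raw_scale=255.0):
    """Turn an H x W x 3 RGB image with values in [0, 1] into a C x H x W input.

    Order of steps: HWC -> CHW transpose, channel swap (RGB -> BGR),
    scaling to [0, 255], then per-channel mean subtraction.
    """
    height, width = len(image_hwc), len(image_hwc[0])
    # HWC -> CHW: build one 2-D plane per channel
    chw = [[[image_hwc[y][x][c] for x in range(width)] for y in range(height)]
           for c in range(3)]
    # Reorder channels; (2, 1, 0) turns RGB into BGR
    chw = [chw[c] for c in channel_swap]
    # Scale, then subtract the mean of the (already swapped) channel
    return [[[value * raw_scale - mean_bgr[c] for value in row] for row in plane]
            for c, plane in enumerate(chw)]
```

For a single pixel with RGB values (1.0, 0.5, 0.0) and a BGR mean of (10, 20, 30), `preprocess` yields B = 0 - 10, G = 127.5 - 20 and R = 255 - 30. Because channel_swap runs before the mean is subtracted, the mean vector must already be in BGR order, which is why the mean file can be used directly.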

Output:

prediction shape: (1000L,)
prediction class: 281

[Figure 1: the input image, examples/images/cat.jpg]

[Figure 2: plot of prediction[0], the probabilities of the 1000 classes]
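The shape (1000,) is one probability per ImageNet class, and 281 is the index of the largest one. net.predict is called with oversample=True by default, so, as far as I can tell from pycaffe's caffe.io.oversample, the 256x256 image is cut into 10 crops of the network's input size (227x227 for CaffeNet): the four corners and the center, plus their horizontal mirrors. The 10 output vectors are then averaged. The sketch below covers only the crop geometry and the averaging; `crop_boxes`, `ten_crops` and `average` are hypothetical helpers, and the center rounding may differ from Caffe's by a pixel for other sizes.

```python
def crop_boxes(image_hw, crop_hw):
    """Return (top, left, bottom, right) boxes for the 4 corners and the center."""
    (h, w), (ch, cw) = image_hw, crop_hw
    # Corners: every combination of top/bottom and left/right alignment
    boxes = [(top, left, top + ch, left + cw)
             for top in (0, h - ch)
             for left in (0, w - cw)]
    # Center crop, rounding the offset down
    top, left = (h - ch) // 2, (w - cw) // 2
    boxes.append((top, left, top + ch, left + cw))
    return boxes


def ten_crops(boxes):
    """Pair each box with a mirror flag: 5 unflipped crops, then the same 5 flipped."""
    return [(box, False) for box in boxes] + [(box, True) for box in boxes]


def average(predictions):
    """Element-wise mean of equally long probability vectors (10 crops -> 1 vector)."""
    return [sum(column) / len(predictions) for column in zip(*predictions)]
```

With `crop_boxes((256, 256), (227, 227))` the corner offsets are 0 and 256 - 227 = 29, and the center crop starts at 29 // 2 = 14. Passing oversample=False to net.predict should skip this and classify only the center crop, which is faster but usually slightly less accurate.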

3. Class labels (synset words)

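The predicted class index can be mapped to a human-readable label using the ImageNet synset word list. To download it, run caffe/data/ilsvrc12/get_ilsvrc_aux.sh: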

#!/usr/bin/env sh
#
# N.B. This does not download the ilsvrcC12 data set, as it is gargantuan.
# This script downloads the imagenet example auxiliary files including:
# - the ilsvrc12 image mean, binaryproto
# - synset ids and words
# - Python pickle-format data of ImageNet graph structure and relative infogain
# - the training splits with labels

DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd $DIR

echo "Downloading..."

wget -c http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz

echo "Unzipping..."

tar -xf caffe_ilsvrc12.tar.gz && rm -f caffe_ilsvrc12.tar.gz

echo "Done."

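If the script cannot connect (for example, from inside a restricted intranet), note that it simply downloads http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz. Download that file manually and extract it into data/ilsvrc12.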
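When neither wget nor a browser is convenient, the same download-and-extract step can be done from Python with only the standard library. This is a minimal sketch, not a replacement for the official script: `fetch_and_extract` is a hypothetical helper, it does not resume interrupted downloads the way `wget -c` does, and the extraction filter is only used on Python versions that provide it.

```python
import os
import shutil
import tarfile
import urllib.request


def fetch_and_extract(url, dest_dir):
    """Download a .tar.gz archive into dest_dir, unpack it there, then delete it.

    Mirrors the wget + tar -xf + rm -f sequence of get_ilsvrc_aux.sh.
    Returns the sorted list of entries left in dest_dir.
    """
    archive = os.path.join(dest_dir, os.path.basename(url))
    # Stream the response to disk instead of holding it all in memory
    with urllib.request.urlopen(url) as response, open(archive, "wb") as out:
        shutil.copyfileobj(response, out)
    with tarfile.open(archive, "r:gz") as tar:
        # Newer Pythons support extraction filters; "data" rejects unsafe paths
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
    os.remove(archive)
    return sorted(os.listdir(dest_dir))
```

For this tutorial the call would be `fetch_and_extract("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", caffe_root + "data/ilsvrc12")`, assuming that directory already exists in your checkout.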
Then add the following code to the end of the script:

# Each line of synset_words.txt is "<wnid> <names>"; line i is the label of class i
labels_file = caffe_root + 'data/ilsvrc12/synset_words.txt'
labels = np.loadtxt(labels_file, str, delimiter='\t')
print('output label: %s' % labels[prediction[0].argmax()])
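synset_words.txt has one line per class, in class-index order, of the form `n02123045 tabby, tabby cat`: a WordNet ID, a space, and then comma-separated names. The lines contain no tab characters, so delimiter='\t' makes np.loadtxt keep each line as a single string, and labels[i] is the whole line for class i. The following sketch performs the same lookup without NumPy and also separates the ID from the names; `parse_synset_words` is a hypothetical helper.

```python
def parse_synset_words(lines):
    """Parse synset_words.txt lines into (wnid, [names]) tuples, one per class.

    The position in the returned list is the class index used by the
    network's output. Blank lines are skipped, as np.loadtxt also does.
    """
    entries = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # The WordNet ID is everything before the first space
        wnid, _, names = line.partition(" ")
        entries.append((wnid, [name.strip() for name in names.split(",")]))
    return entries
```

In the script, `parse_synset_words(open(labels_file))[281]` would then give the ID and names for class 281.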

Result:

output label: n02123045 tabby, tabby cat

You can also substitute your own photo, for example Lena:

[Figure 3: the Lena test image]

Result:

output label: n02869837 bonnet, poke bonnet

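That is, a bonnet (a wide-brimmed women's hat), which is a fairly accurate description.
You can also print the top few predictions instead of only the best one.
Change the code to: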

labels_file = caffe_root + 'data/ilsvrc12/synset_words.txt'
labels = np.loadtxt(labels_file, str, delimiter='\t')
# Indices of the 5 highest probabilities, in descending order
label_inds = prediction[0].argsort()[::-1][:5]
print('output label: %s' % list(zip(prediction[0][label_inds], labels[label_inds])))
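`argsort()[::-1][:5]` sorts all 1000 probabilities just to keep five. A sketch of the same top-k selection in plain Python uses `heapq.nlargest`, which avoids a full sort; `top_k` is a hypothetical helper, and for tied probabilities it may order indices differently from NumPy's argsort.

```python
import heapq


def top_k(probs, k=5):
    """Return (index, probability) pairs for the k largest probabilities, highest first.

    Same intent as probs.argsort()[::-1][:k] on a NumPy array.
    """
    return heapq.nlargest(k, enumerate(probs), key=lambda pair: pair[1])
```

With the script's variables, `[(p, labels[i]) for i, p in top_k(prediction[0].tolist())]` should produce the same kind of (probability, label) list as the code above.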

Output for the cat image:

output label: [(0.27352503, 'n02123045 tabby, tabby cat'), (0.23680504, 'n02123159 tiger cat'), (0.15174942, 'n02124075 Egyptian cat'), (0.12903579, 'n02127052 lynx, catamount'), (0.047049168, 'n02119789 kit fox, Vulpes macrotis')]

Output for the Lena image:

output label: [(0.20391516, 'n02869837 bonnet, poke bonnet'), (0.1764992, 'n04259630 sombrero'), (0.1642044, 'n03124170 cowboy hat, ten-gallon hat'), (0.058862317, 'n04584207 wig'), (0.050728988, "n02669723 academic gown, academic robe, judge's robe")]
