Caffe windows下入门级别的从编译到训练然后到分类(用python接口)测试自己的图片数据(四)

前言

前面我们已经训练好了模型这时候模型文件是这样的:

Caffe windows下入门级别的从编译到训练然后到分类(用python接口)测试自己的图片数据(四)_第1张图片
我这里演示的所以设置的迭代次数不多的。

测试自己的数据

转换均值文件(将二进制的均值文件转换为npy这里要使用的均值文件):
import numpy as np
import caffe
import sys

BinaryMeanPath = '~~~~~~~~~~/mean.binaryproto'
NpyMeanOuPath = '~~~~~~~~~~/meannpy.npy'


print 'Start.............'
blob = caffe.proto.caffe_pb2.BlobProto()
data = open( BinaryMeanPath , 'rb' ).read()
blob.ParseFromString(data)
arr = np.array( caffe.io.blobproto_to_array(blob) )
out = arr[0]
np.save( NpyMeanOuPath , out )
print 'Complete.............'
然后caffe-windows\models\bvlc_reference_caffenet\这个文件夹中取出deploy这个文件,修改最后的输出数量改为我们自己的分类数量
最后,这里测试自己的数据的代码是直接从官网上拿下来的:
# coding:utf-8
import numpy as np

MyCaffeRoot = '~~~~~~~~~~/mymnist/'
ImgTestPath = '~~~~~~~~~~/1/pic_hashiqi_Pos120.jpg' #测试图片路径
LabelsPath = MyCaffeRoot + 'labels.txt'

import sys
import caffe
import os

CaffeModelPath = MyCaffeRoot + 'caffenet_train_iter_4500.caffemodel'
DeployPath = MyCaffeRoot + 'deploy.prototxt'
NpyMeanPath = '~~~~~~~~~~/mean/meannpy.npy'

if os.path.exists(CaffeModelPath) == False:
    print u'找不到模型的路径'
else:
    print u'找到模型的路径......'

caffe.set_mode_cpu();

net = caffe.Net(DeployPath, CaffeModelPath, caffe.TEST)  #创建网络
#负载均衡减去均值
mu = np.load(NpyMeanPath)
mu = mu.mean(1).mean(1)
print u'各个颜色通道的均值:', zip('BGR', mu)

transformer = caffe.io.Transformer({'data':net.blobs['data'].data.shape})
transformer.set_transpose('data',(2, 0, 1))
transformer.set_mean('data',mu)
transformer.set_raw_scale('data', 255)
transformer.set_channel_swap('data',(2, 1, 0))
net.blobs['data'].reshape(50, 3, 227, 227)

#执行测试
out = net.forward()

# transform it and copy it into the net
image = caffe.io.load_image(ImgTestPath)
net.blobs['data'].data[...] = transformer.preprocess('data', image)

# perform classification
net.forward()

# obtain the output probabilities
output_prob = net.blobs['prob'].data[0]

#验证标签文件是否存在
if os.path.exists(LabelsPath) == False:
    print u'标签文件不存在'
    exit(0)
#读取标签文件
labels = np.loadtxt(LabelsPath, str, delimiter='\t')
sort = output_prob.argsort()[::-1][:2]
#labels[0] = '哈士奇'
print output_prob
print u'这个是--->:' , str(labels[sort[0]]).decode('utf-8')
print u'这个是--->:' , str(labels[sort[1]]).decode('utf-8')

本教程到此结束,如果有哪里错误的欢迎指出来我会及时修改的。

你可能感兴趣的:(windows,数据,图片,caffe,测试,caffe)