Caffe 模型调用预测示例

如下给出调用已经训练好的 Caffe 模型进行测试的示例代码：

包括：模型加载与调用、耗时统计、准确率统计，以及将预测结果写入 txt 文件。

import os
import caffe
import numpy as np
import time

# Run all Caffe computation on the CPU.
caffe.set_mode_cpu()

# Locations of the network definition and the trained ShuffleNet weights.
root = '/home/xuqiong/makeall/caffe/'
deploy = root + 'examples/xq0523pm/shufflenet_deploy.prototxt'
caffe_model = root + 'examples/xq0523pm/shufflenet_train_iter_50000.caffemodel'
# NOTE(review): the mean is later read with np.load, which expects a .npy file;
# a Caffe mean.binaryproto would have to be converted first.
#mean_file = '/mnt/data2/xuqiong/data/split/mean.binaryprototxt'

# Collect every file in the pristine test-image directory. The directory
# listing is sorted so that the output order (and label.txt) is reproducible;
# os.listdir returns entries in arbitrary order.
test = '/mnt/data2/xuqiong/data/split/test/'
image_dir = test + 'pristine_images/'  # renamed from `dir`, which shadowed the builtin
filelist = [os.path.join(image_dir, fn) for fn in sorted(os.listdir(image_dir))]

def _expected_class(dirname):
    """Return the reference class index for images stored in *dirname*.

    'pristine_images' maps to class 5. A directory whose name ends in a
    digit d between 1 and 5 maps to class 5 - d. Any other directory has
    no known label, and None is returned.
    """
    if dirname == 'pristine_images':
        return 5
    suffix = dirname[-1:]
    # Compare against characters: the previous `path[-1] == 1` compared a
    # str to an int and was therefore always False.
    if suffix in ('1', '2', '3', '4', '5'):
        return 5 - int(suffix)
    return None


# Load the network and build the preprocessing pipeline once. Rebuilding
# them per image (as before) is slow and does not change any result.
caffe.set_mode_cpu()
net = caffe.Net(deploy, caffe_model, caffe.TEST)

transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})  # input shape from the deploy prototxt
transformer.set_transpose('data', (2, 0, 1))  # (H,W,C) -> (C,H,W)
# NOTE(review): np.load expects a .npy mean file, not a .binaryproto.
#transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))  # minus mean
transformer.set_raw_scale('data', 255)  # scale pixel values to [0,255]
transformer.set_channel_swap('data', (2, 1, 0))  # RGB -> BGR

count = 0  # images whose predicted class matches the reference class
timei_all = 0.0  # accumulated load + preprocess time (s)
timef_all = 0.0  # accumulated forward-pass time (s)
timem_all = 0.0  # accumulated post-processing time (s)
timea_all = 0.0  # accumulated total time (s)

# Open the result file once; the previous code reopened it every iteration
# and only closed the final handle.
with open("/home/xuqiong/makeall/caffe/examples/xq0523pm/label.txt", "a") as label_file:
    for img in filelist:
        time0 = time.time()
        im = caffe.io.load_image(img)
        net.blobs['data'].data[...] = transformer.preprocess('data', im)
        time1 = time.time()

        net.forward()
        time2 = time.time()

        prob = net.blobs['fcout6'].data[0].flatten()  # scores of the last layer
        # Predicted class = index of the highest score. argsort()[5] only
        # equalled this when the layer had exactly 6 outputs.
        order = int(prob.argmax())
        time3 = time.time()

        timei = time1 - time0
        timef = time2 - time1
        timem = time3 - time2
        timea = timei + timef + timem
        timei_all += timei
        timef_all += timef
        timem_all += timem
        timea_all += timea
        print("timei: %.2f ms" % (timei * 1000))
        print("timef: %.2f ms" % (timef * 1000))
        print("timem: %.2f ms" % (timem * 1000))
        print("timea: %.2f ms" % (timea * 1000))
        print("the class is: %d" % order)

        label_file.write('%s %d\n' % (img, order))

        # Accuracy bookkeeping: compare against the label implied by the
        # name of the directory that contains the image.
        expected = _expected_class(os.path.basename(os.path.dirname(img)))
        if expected is not None and order == expected:
            count += 1

n_images = len(filelist)
print("shufflenetv1, cpu")
print("ok count: %d" % count)
print("all count: %d" % n_images)
if n_images:
    # Previously "%.2f" was passed as a separate print argument and was
    # printed literally instead of formatting the value.
    print("accuracy: %.4f" % (float(count) / n_images))
    print("timei average: %.2f ms" % (timei_all * 1000 / n_images))
    print("timef average: %.2f ms" % (timef_all * 1000 / n_images))
    print("timem average: %.2f ms" % (timem_all * 1000 / n_images))
    print("timea average: %.2f ms" % (timea_all * 1000 / n_images))
else:
    print("no test images found; averages skipped")

 

你可能感兴趣的:(caffe)