caffe模型训练后使用python接口测试-分类模型-20200617

  1. caffe 模型训练后，使用 Python 接口进行测试，在此做记录。
    From: 言有三
import caffe
import numpy as np
import cv2

def start_test(model_proto, model_weight, imgtxt, testsize, enable_crop):
    """Measure top-1 accuracy of a Caffe classification model on an image list.

    Each non-empty line of ``imgtxt`` is parsed as ``<image_path> <label>``:
    the label is the last whitespace-separated field and everything before
    it is the path, so paths may contain spaces.

    Every image is either center-cropped (``enable_crop == 1``) or resized
    to ``testsize`` x ``testsize``. It is then mean-subtracted, transposed
    to CHW and run through the network. The argmax of the ``'classifier'``
    output blob is compared (as a string) with the label.

    Args:
        model_proto: Path to the deploy ``.prototxt``.
        model_weight: Path to the trained ``.caffemodel``.
        imgtxt: Path to the image list file.
        testsize: Side length in pixels of the square crop/resize.
        enable_crop: ``1`` to center-crop; any other value resizes.
            Images smaller than ``testsize`` are always resized.

    Returns:
        The top-1 accuracy as a float (also printed), or ``None`` if no
        image could be evaluated.
    """
    # set_device() only selects the GPU id; without set_mode_gpu() pycaffe
    # keeps running in its default CPU mode.
    caffe.set_device(0)
    caffe.set_mode_gpu()
    net = caffe.Net(model_proto, model_weight, caffe.TEST)

    # The preprocessing pipeline is identical for every image, so build it
    # once instead of on every loop iteration.
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    # Per-channel mean, applied in the channel order the image arrives in
    # (cv2.imread yields BGR).
    transformer.set_mean('data', np.array([104.008, 116.669, 122.675]))
    # HWC (OpenCV layout) -> CHW (Caffe layout).
    transformer.set_transpose('data', (2, 0, 1))

    valid_exts = {'png', 'jpg', 'jpeg', 'tif', 'bmp'}

    with open(imgtxt, 'r') as f:
        lines = f.readlines()

    if enable_crop == 1:
        print("use crop")

    count = 0
    acc = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        # Split on the last run of whitespace so paths may contain spaces.
        fields = line.rsplit(None, 1)
        if len(fields) != 2:
            print("---------bad list line---------", line)
            continue
        imgpath, label = fields

        imgtype = imgpath.split('.')[-1]
        if imgtype.lower() not in valid_exts:
            print(imgtype, "error")
            continue

        img = cv2.imread(imgpath)
        if img is None:
            print("---------img is empty---------", imgpath)
            continue

        imgheight, imgwidth, channel = img.shape
        # Choose between center-crop and resize.
        if enable_crop == 1 and imgwidth >= testsize and imgheight >= testsize:
            cropx = (imgwidth - testsize) // 2
            cropy = (imgheight - testsize) // 2
            img = img[cropy:cropy + testsize, cropx:cropx + testsize, 0:channel]
        else:
            # Also used when the image is smaller than testsize: a center
            # crop would get negative offsets and a wrong, undersized slice.
            img = cv2.resize(img, (testsize, testsize),
                             interpolation=cv2.INTER_NEAREST)

        out = net.forward_all(data=np.asarray([transformer.preprocess('data', img)]))
        result = out['classifier'][0]
        print("result=", result)
        predict = np.argmax(result)
        if label == str(predict):
            acc += 1
        count += 1

    if count == 0:
        print("acc= no images evaluated")
        return None
    accuracy = float(acc) / float(count)
    print("acc=", accuracy)
    return accuracy

# Script entry point: evaluate the fine-tuned MobileNet on the validation
# list with 96x96 center crops. The original `=` / unquoted `__main__`
# was a SyntaxError.
if __name__ == "__main__":
    start_test('deploy.prototxt', 'models/mobilenet_finetune_iter_2000.caffemodel', 'all_shuffle_val.txt', 96, 1)

你可能感兴趣的:(caffe)