# Python + OpenCV: classify images with a Caffe model (Python +OpenCV接口 应用caffe预测分类)

#import caffe
#import lmdb
import numpy as np
import cv2
#from caffe.proto import caffe_pb2
import os
import sys
import time
import id_deploy
import caffe
#caffe.set_mode_gpu()

def dirlist(path, allfile):
    """Recursively collect every file path beneath *path* into *allfile*.

    Entries are visited in ``os.listdir`` order, and sub-directories are
    descended into depth-first. A sub-folder's files therefore appear at the
    position where the sub-folder itself was listed. Every non-directory entry
    is collected; nothing is filtered by extension or content.

    Args:
        path: Directory to scan.
        allfile: List that discovered paths are appended to in place.

    Returns:
        The same ``allfile`` list object, for convenience.
    """
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if not os.path.isdir(full_path):
            allfile.append(full_path)
            continue
        # Recurse, appending into the shared accumulator.
        dirlist(full_path, allfile)
    return allfile



# sys.setrecursionlimit(1000000)

def is_bgr_img(img):
    """Return True if *img* looks like a loaded multi-channel image.

    The check is purely structural. *img* must have a ``shape`` attribute with
    exactly three dimensions, e.g. (height, width, channels). The channel
    count itself is not checked.

    The function returns False for:
    - ``None``;
    - any object without a ``shape`` attribute;
    - arrays of any other rank, such as a 2-D grayscale image.

    The previous version raised ``ValueError`` for arrays whose rank was not 3.
    """
    shape = getattr(img, 'shape', None)
    return shape is not None and len(shape) == 3

# Output sub-folder name for each predicted class index (index i -> dirs[i]).
dirs = ['0_BM', '1_WD', '2_DO', '3_WJ', '4_WT', '5_YM', '6_CD', '7_SG', '8_TW', '9_WZ']

# NOTE(review): 'F:zfh20181012' has no slash after the drive letter.
# On Windows that is a drive-relative path, resolved against the current
# directory of drive F:. If the folder is at the drive root, this should be
# 'F:/zfh20181012'. Confirm.
imgnames = dirlist('F:zfh20181012', [])
path = 'f:/out_zfh/'  # root folder that receives the sorted copies
caffe_model = 'AI_Stomach.caffemodel'  # trained Caffe weights

# Fail with a clear message instead of an IndexError on an empty input folder.
if not imgnames:
    sys.exit('No input files found; nothing to classify.')
temp = imgnames[0]
print(temp)

# readNetFromCaffe is given a file path here, so the network definition held
# in id_deploy (presumably the prototxt text) is written to a temporary file.
deploy_temp_file = 'deepid_capsule_protxt.prototxt'
with open(deploy_temp_file, 'w', encoding='utf-8') as handle_file:
    handle_file.write(id_deploy._deepid_capsule_protxt)
try:
    # Load the network definition together with the trained weights.
    caffenet = cv2.dnn.readNetFromCaffe(deploy_temp_file, caffe_model)
finally:
    # Remove the temporary prototxt even if loading the model fails.
    os.remove(deploy_temp_file)

#transformer = caffe.io.Transformer({'data': (1,3,128,128)})  # 设定图片的shape格式(1,3,28,28)
#transformer.set_transpose('data', (2, 0, 1)) 
# Predictions at or below this softmax probability go to the 'unkown' folder.
CONFIDENCE_THRESHOLD = 0.5

for imgname in imgnames:
    image = cv2.imread(imgname)
    print(imgname)
    if image is None:
        # cv2.imread signals failure by returning None.
        # NOTE(review): this permanently deletes the *source* file. cv2.imread
        # can also fail on valid images, e.g. paths with non-ASCII characters
        # on Windows. Confirm that deletion is really intended.
        print('unreadable, removing:', imgname)
        os.remove(imgname)
        continue

    # Build a 4-D blob from the image. Arguments: scale 1.0, resize to
    # 128x128, mean (0, 0, 0), no R/B channel swap, no crop.
    blob = cv2.dnn.blobFromImage(image, 1.0, (128, 128), (0, 0, 0), False, False)
    caffenet.setInput(blob)
    detections = caffenet.forward('softmax')

    # Flatten the first sample's scores so argmax works whether the softmax
    # output is shaped (1, C) or (1, C, 1, 1).
    prob = detections[0].ravel()
    order = int(np.argmax(prob))
    prob_class = prob[order]
    print('the predict class is:', order)

    # Confident predictions go to their class folder, the rest to 'unkown'.
    # The misspelling is kept because it is an existing on-disk folder name.
    label = dirs[order] if prob_class > CONFIDENCE_THRESHOLD else 'unkown'
    imgpath = os.path.join(path, label)
    # makedirs also creates the output root if it does not exist yet.
    os.makedirs(imgpath, exist_ok=True)
    # os.path.basename accepts the platform's separators; the old split('\\')
    # only handled backslashes. Files with the same name in different input
    # sub-folders still overwrite each other here.
    out_file = os.path.join(imgpath, os.path.basename(imgname))
    if not cv2.imwrite(out_file, image):
        print('failed to write', out_file)

    cv2.imshow('cv2', image)
    k = cv2.waitKey(1)
    if k == 27:  # Esc: stop processing
        break
    if k == 32:  # Space: pause until another key is pressed
        cv2.waitKey()

cv2.destroyAllWindows()

 

# Related topics (blog footer): 你可能感兴趣的: 机器学习 (machine learning)