py-faster-rcnn_caffemodel对人脸进行标注

本程序在py-faster-rcnn/tools/demo.py的基础上进行修改

程序功能:利用训练好的caffemodel,对人脸进行标注

[python]  view plain  copy
  1. #!/usr/bin/env python  
  2.   
  3. # --------------------------------------------------------  
  4. # Faster R-CNN  
  5. # Copyright (c) 2015 Microsoft  
  6. # Licensed under The MIT License [see LICENSE for details]  
  7. # Written by Ross Girshick  
  8. # --------------------------------------------------------  
  9.   
  10. """ 
  11. Demo script showing detections in sample images. 
  12.  
  13. See README.md for installation instructions before running. 
  14. """  
  15.   
  16. import _init_paths  
  17. from fast_rcnn.config import cfg  
  18. from fast_rcnn.test import im_detect  
  19. from fast_rcnn.nms_wrapper import nms  
  20. from utils.timer import Timer  
  21. import matplotlib.pyplot as plt  
  22. import numpy as np  
  23. import scipy.io as sio  
  24. import caffe, os, sys, cv2  
  25. import argparse  
  26.   
  27. #CLASSES = ('__background__',  
  28. #           'aeroplane', 'bicycle', 'bird', 'boat',  
  29. #           'bottle', 'bus', 'car', 'cat', 'chair',  
  30. #           'cow', 'diningtable', 'dog', 'horse',  
  31. #           'motorbike', 'person', 'pottedplant',  
  32. #           'sheep', 'sofa', 'train', 'tvmonitor')  
  33.   
  34. CLASSES = ('__background__','face')  
  35.   
  36. NETS = {'vgg16': ('VGG16',  
  37.                   'VGG16_faster_rcnn_final.caffemodel'),  
  38.         'myvgg': ('VGG_CNN_M_1024',  
  39.                   'VGG_CNN_M_1024_faster_rcnn_final.caffemodel'),  
  40.         'zf': ('ZF',  
  41.                   'ZF_faster_rcnn_final.caffemodel'),  
  42.         'myzf': ('ZF',  
  43.                   'zf_rpn_stage1_iter_80000.caffemodel'),  
  44. }  
  45.   
  46.   
  47. def vis_detections(im, class_name, dets, thresh=0.5):  
  48.     """Draw detected bounding boxes."""  
  49.     inds = np.where(dets[:, -1] >= thresh)[0]  
  50.     if len(inds) == 0:  
  51.         return  
  52.   
  53.     #write_file.write(array[current_image] + ' ') #add by zhipeng  
  54.     #write_file.write('face' + ' ') #add by zhipeng  
  55.     im = im[:, :, (210)]  
  56.     #fig, ax = plt.subplots(figsize=(12, 12))  
  57.     #ax.imshow(im, aspect='equal')  
  58.     for i in inds:  
  59.         bbox = dets[i, :4]  
  60.         score = dets[i, -1]  
  61.   
  62.         write_file.write(array[current_image] + ' '#add by zhipeng  
  63.         #write_file.write('face' + ' ')  
  64.         ##########   add by zhipeng for write rectange to txt   ########  
  65.         #bbox[0]:x, bbox[1]:y, bbox[2]:x+w, bbox[3]:y+h  
  66.         write_file.write( "{} {} {} {}\n".format(str(int(bbox[0])), str(int(bbox[1])),  
  67.                                                         str(int(bbox[2])-int(bbox[0])),  
  68.                                                         str(int(bbox[3])-int(bbox[1]))))  
  69.         #print "zhipeng, bbox:", bbox, "score:",score  
  70.         ##########   add by zhipeng for write rectange to txt   ########  
  71.   
  72.           
  73.   
  74. def demo(net, image_name):  
  75.     """Detect object classes in an image using pre-computed object proposals."""  
  76.   
  77.     # Load the demo image  
  78.     #im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)  
  79.     im = cv2.imread(image_name)  
  80.   
  81.     # Detect all object classes and regress object bounds  
  82.     timer = Timer()  
  83.     timer.tic()  
  84.     scores, boxes = im_detect(net, im)  
  85.     timer.toc()  
  86.     print ('Detection took {:.3f}s for '  
  87.            '{:d} object proposals').format(timer.total_time, boxes.shape[0])  
  88.   
  89.     # Visualize detections for each class  
  90.     CONF_THRESH = 0.8  
  91.     NMS_THRESH = 0.3  
  92.     for cls_ind, cls in enumerate(CLASSES[1:]):  
  93.         cls_ind += 1 # because we skipped background  
  94.         cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]  
  95.         cls_scores = scores[:, cls_ind]  
  96.         dets = np.hstack((cls_boxes,  
  97.                           cls_scores[:, np.newaxis])).astype(np.float32)  
  98.         keep = nms(dets, NMS_THRESH)  
  99.         dets = dets[keep, :]  
  100.         vis_detections(im, cls, dets, thresh=CONF_THRESH)  
  101.   
  102. def parse_args():  
  103.     """Parse input arguments."""  
  104.     parser = argparse.ArgumentParser(description='Faster R-CNN demo')  
  105.     parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',  
  106.                         default=0, type=int)  
  107.     parser.add_argument('--cpu', dest='cpu_mode',  
  108.                         help='Use CPU mode (overrides --gpu)',  
  109.                         action='store_true')  
  110.     parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',  
  111.                         choices=NETS.keys(), default='vgg16')  
  112.   
  113.     args = parser.parse_args()  
  114.   
  115.     return args  
  116.   
  117. if __name__ == '__main__':  
  118.     cfg.TEST.HAS_RPN = True  # Use RPN for proposals  
  119.   
  120.     args = parse_args()  
  121.   
  122.     prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],  
  123.                             'faster_rcnn_alt_opt''faster_rcnn_test.pt')  
  124.     caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',  
  125.                               NETS[args.demo_net][1])  
  126.   
  127.     if not os.path.isfile(caffemodel):  
  128.         raise IOError(('{:s} not found.\nDid you run ./data/script/'  
  129.                        'fetch_faster_rcnn_models.sh?').format(caffemodel))  
  130.   
  131.     if args.cpu_mode:  
  132.         caffe.set_mode_cpu()  
  133.     else:  
  134.         caffe.set_mode_gpu()  
  135.         caffe.set_device(args.gpu_id)  
  136.         cfg.GPU_ID = args.gpu_id  
  137.     net = caffe.Net(prototxt, caffemodel, caffe.TEST)  
  138.   
  139.     print '\n\nLoaded network {:s}'.format(caffemodel)  
  140.   
  141.     # Warmup on a dummy image  
  142.     im = 128 * np.ones((3005003), dtype=np.uint8)  
  143.     for i in xrange(2):  
  144.         _, _= im_detect(net, im)  
  145.   
  146.     '''''im_names = ['000456.jpg', '000542.jpg', '001150.jpg', 
  147.                 '001763.jpg', '004545.jpg']'''  
  148.   
  149.     ##########   add by zhipeng for write rectange to txt   ########  
  150.     #write_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/tools/detections/out.txt'  
  151.     #write_file = open(write_file_name, "w")  
  152.     ##########   add by zhipeng for write rectange to txt   ########  
  153.   
  154. #    for current_file in range(1,11):      #orginal range(1, 11)  
  155.   
  156. #    print 'Processing file ' + str(current_file) + ' ...'  
  157.   
  158.     read_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/data/pos_fold/name.txt'  
  159.     write_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/data/pos_fold/annotate.txt'  
  160.     write_file = open(write_file_name, "w")  
  161.   
  162.     with open(read_file_name, "r") as ins:  
  163.         array = []  
  164.         for line in ins:  
  165.             line = line[:-1]  
  166.             array.append(line)      # list of strings  
  167.   
  168.     number_of_images = len(array)  
  169.   
  170.     for current_image in range(number_of_images):  
  171.         if current_image % 100 == 0:  
  172.             print 'Processing image : ' + str(current_image)  
  173.         # load image and convert to gray  
  174.         read_img_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/data/pos/' + array[current_image].rstrip()  
  175.         #write_file.write(array[current_image]) #add by zhipeng  
  176.         demo(net, read_img_name)  
  177.   
  178.     write_file.close()  
  179.   
  180.     '''''for im_name in im_names: 
  181.         print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' 
  182.         print 'Demo for data/demo/{}'.format(im_name) 
  183.         write_file.write(im_name + '\n') #add by zhipeng 
  184.         demo(net, im_name)'''  
  185.   
  186.     #write_file.close()  # add by zhipeng,close file  
  187.     plt.show()  

你可能感兴趣的:(人脸识别)