caffe下使用faster-rcnn测试test.txt中所有图片并保存

caffe下使用faster-rcnn测试test.txt中所有图片并保存

直接读取…\py-faster-rcnn-master\data\VOCdevkit2007\VOC2007\ImageSets\Main\test.txt中的测试集编号进行测试。
原demo.py文件：检测一张图片，并将该图片每一类的检测结果分别单独显示。
修改后：从txt中读取要检测的图片名称进行批量检测，每张图上同时显示所有类别的检测结果，并保存到自定义目录。

运行代码前需要将全部图片（或者仅是测试集图片）复制到…\py-faster-rcnn-master\data\demo\目录下。另外，代码中test.txt的路径和结果保存目录都是写死的绝对路径（E:\jianchuan\…），运行前需要改成自己的路径。

import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse
from PIL import Image

# Detection class names in network output order. Index 0 is the background
# class, which the drawing loop skips (it iterates CLASSES[1:]).
CLASSES = (
    '__background__',
    'ali', 'ddg1000', 'jg3', 'jilong',
)

# --net choice -> (model sub-directory under cfg.MODELS_DIR,
#                  caffemodel file name under <cfg.DATA_DIR>/faster_rcnn_models).
NETS = {
    'vgg16': ('VGG16', 'VGG16_faster_rcnn_final.caffemodel'),
    'vgg1024': ('VGG_CNN_M_1024',
                'vgg_cnn_m_1024_faster_rcnn_iter_70000.caffemodel'),
    'zf': ('ZF', 'ZF_faster_rcnn_final.caffemodel'),
}

# 此处删除原来的函数“def vis_detections(im, class_name, dets, thresh=0.5):”

# Default directory for the annotated result images (adjust to your machine).
SAVE_DIR = r'E:\jianchuan\py-faster-rcnn-master_4types_2km\save_results\save_demo_test'


def _draw_detections(ax, class_name, dets, thresh):
    """Draw every row of ``dets`` whose score is >= ``thresh`` onto ``ax``.

    Each row of ``dets`` is ``(x1, y1, x2, y2, score)``, as assembled in
    ``demo``; a red rectangle plus a "<class> <score>" label is drawn per box.
    """
    for x1, y1, x2, y2, score in dets[dets[:, -1] >= thresh]:
        ax.add_patch(
            plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False,
                          edgecolor='red', linewidth=1.5)
        )
        ax.text(x1, y1 - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')


def demo(net, image_name, save_dir=SAVE_DIR, conf_thresh=0.7, nms_thresh=0.3):
    """Detect all CLASSES in one image and save one annotated PNG.

    Args:
        net: loaded caffe.Net passed through to ``im_detect``.
        image_name: file name of the image inside ``<cfg.DATA_DIR>/demo``.
        save_dir: directory the result is written to; created if missing.
        conf_thresh: minimum score for a box to be drawn.
        nms_thresh: IoU threshold handed to ``nms`` for each class.

    The result is saved as ``<save_dir>/<image stem>.png`` with the boxes of
    every class drawn on the same figure.

    Raises:
        IOError: if the image cannot be read.
    """
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread reports a missing/unreadable file by returning None,
        # which would otherwise fail later inside im_detect.
        raise IOError('Could not read image {:s}'.format(im_file))

    # Detect all object classes and regress object bounds.
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    timer.toc()
    # Format first, then print: applying .format to print()'s return value
    # only works with the Python 2 print statement.
    print('Detection took {:.3f}s for {:d} object proposals'.format(
        timer.total_time, boxes.shape[0]))

    # Replace only the extension (a plain replace('jpg', 'png') would also
    # rewrite "jpg" appearing elsewhere in the name).
    save_path = os.path.join(save_dir,
                             os.path.splitext(image_name)[0] + '.png')

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        # cv2.imread returns BGR; reverse the channels for matplotlib.
        ax.imshow(im[:, :, (2, 1, 0)], aspect='equal', alpha=0.5)

        # Column block 4*k:4*(k+1) of `boxes` and column k of `scores`
        # belong to CLASSES[k]; start at 1 to skip the background class.
        for cls_ind, cls in enumerate(CLASSES[1:], 1):
            cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
            cls_scores = scores[:, cls_ind]
            dets = np.hstack((cls_boxes,
                              cls_scores[:, np.newaxis])).astype(np.float32)
            dets = dets[nms(dets, nms_thresh), :]
            _draw_detections(ax, cls, dets, conf_thresh)

        ax.axis('off')
        fig.tight_layout()
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        fig.savefig(save_path)
    finally:
        # One figure is created per image; close it so a batch run does not
        # keep every figure alive.
        plt.close(fig)
    print(save_path)

def parse_args(argv=None):
    """Parse command-line options for the batch demo.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``gpu_id`` (int), ``cpu_mode`` (bool) and
        ``demo_net`` (one of the keys of NETS).
    """
    default_net = 'vgg1024'
    parser = argparse.ArgumentParser(description='Faster R-CNN demo')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--cpu', dest='cpu_mode',
                        help='Use CPU mode (overrides --gpu)',
                        action='store_true')
    # Help text is built from the real default so the two cannot drift apart
    # (it previously advertised [vgg16] while defaulting to vgg1024).
    # sorted() gives a stable choice listing in --help.
    parser.add_argument('--net', dest='demo_net',
                        help='Network to use [{}]'.format(default_net),
                        choices=sorted(NETS), default=default_net)
    return parser.parse_args(argv)

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    args = parse_args()

    prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
                            'faster_rcnn_end2end', 'test.prototxt')
    caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',
                              NETS[args.demo_net][1])

    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\nDid you run ./data/script/'
                       'fetch_faster_rcnn_models.sh?').format(caffemodel))

    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
        cfg.GPU_ID = args.gpu_id
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    print('\n\nLoaded network {:s}'.format(caffemodel))

    # Image ids to test, one per line. A raw string keeps the Windows
    # backslashes literal (previously "\\test.txt" had to be escaped by hand
    # so that "\t" would not become a tab).
    test_list = (r'E:\jianchuan\py-faster-rcnn-master_4types_2km\data'
                 r'\VOCdevkit2007\VOC2007\ImageSets\Main\test.txt')
    with open(test_list) as fi:
        # strip() removes '\n', '\r' and stray spaces; blank lines are skipped
        # so they do not turn into a bogus '.jpg' entry.
        im_names = [line.strip() + '.jpg' for line in fi if line.strip()]
    print(im_names)

    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(net, im_name)

参考:Faster批量测试且所有类检测结果都显示在一张图上(TensorFlow实现faster-rcnn)

你可能感兴趣的:(faster-rcnn)