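Some tips for training networks with pvanet

A collection of small, scattered tips that nevertheless come up again and again when training on different datasets.

1. Learning rate

solver.txt records several parameters: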
train_net: "models/pvanet/example_train_384/train.prototxt"
base_lr: 0.001
lr_policy: "step"
gamma: 0.1         # step policy: multiply the learning rate α by gamma every stepsize iterations
stepsize: 50000    # lower the learning rate every 50K iterations
display: 1         # print the loss every `display` iterations
average_loss: 100
momentum: 0.9      # momentum μ = 0.9
weight_decay: 0.0002


# We disable standard caffe solver snapshotting and implement our own snapshot
# function
snapshot: 0 # number of iterations between caffemodel snapshots written to the pvanet directory; these are saved separately from the snapshots configured in config.py
# We still use the snapshot prefix, though
snapshot_prefix: "pvanet_frcnn_384"
iter_size: 2
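
As a rough illustration of what the "step" policy in this solver does, the sketch below computes the learning rate at a given iteration as base_lr * gamma^floor(iter / stepsize). The function name `step_lr` is made up for this example; the default values are the ones from the solver above.

```python
def step_lr(iteration, base_lr=0.001, gamma=0.1, stepsize=50000):
    """Learning rate under a "step" policy: base_lr * gamma ** floor(iteration / stepsize)."""
    return base_lr * gamma ** (iteration // stepsize)


# Print the rate just before and just after each drop.
for it in (0, 49999, 50000, 100000):
    print('iter {:6d}: lr = {:.6f}'.format(it, step_lr(it)))
```

With these values the rate stays at 0.001 for the first 50K iterations and is divided by 10 at every further multiple of 50K.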
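The Caffe tutorial explains this clearly: stochastic gradient descent (type: "SGD") updates the weights W with a linear combination of the negative gradient ∇L(W) and the previous weight update V_t. The learning rate α is the weight of the negative gradient; the momentum μ is the weight of the previous update.
The formulas below compute the new update V_{t+1} and the new weights W_{t+1} from the previous update V_t and the current weights W_t:
 V_{t+1} = μ V_t − α ∇L(W_t)
 W_{t+1} = W_t + V_{t+1}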
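
The two update formulas can be tried out on a toy problem. This is only a sketch: `sgd_momentum_step` is a hypothetical helper, the quadratic loss is chosen purely for illustration, and weight decay (which the solver also applies) is left out.

```python
import numpy as np


def sgd_momentum_step(w, v, grad, lr=0.001, momentum=0.9):
    """One update: V_{t+1} = mu * V_t - lr * grad,  W_{t+1} = W_t + V_{t+1}."""
    v_next = momentum * v - lr * grad
    w_next = w + v_next
    return w_next, v_next


# Toy loss L(w) = 0.5 * ||w||^2, whose gradient is simply w.
w = np.array([1.0, -2.0])
v = np.zeros_like(w)
for step in range(3):
    w, v = sgd_momentum_step(w, v, grad=w, lr=0.1, momentum=0.9)
    print('step {}: w = {}, v = {}'.format(step + 1, w, v))
```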

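However, another paper, [Systematic evaluation of CNN advances on the ImageNet], suggests that the best initial learning rate is lr = 0.01 * batch_size / 256. I tried it in my own experiments, and the formula proved reasonable and worked well.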
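
A minimal sketch of that rule of thumb follows. The helper name `suggested_base_lr` is hypothetical, and which quantity counts as batch_size in a detection setup like this one (for example, whether iter_size should be folded in) is not stated in the text, so treat the input as an assumption.

```python
def suggested_base_lr(batch_size, ref_lr=0.01, ref_batch=256):
    """Scale the reference learning rate linearly with the batch size."""
    return ref_lr * batch_size / ref_batch


for bs in (256, 128, 32):
    print('batch_size {:3d}: base_lr = {:.5f}'.format(bs, suggested_base_lr(bs)))
```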


2. Anchor boxes

# C++ implementation of the proposal layer
layer {
  name: 'proposal'
  type: 'Proposal'
  bottom: 'rpn_cls_prob_reshape'
  bottom: 'rpn_bbox_pred'
  bottom: 'im_info'
  top: 'rpn_rois'
  top: 'rpn_scores'
  proposal_param {
    ratio: 0.5 ratio: 0.667 ratio: 1.0 ratio: 1.5 ratio: 2.0
    scale: 3 scale: 6 scale: 9 scale: 16 scale: 32
    base_size: 16
    feat_stride: 16
    pre_nms_topn: 12000
    post_nms_topn: 200
    nms_thresh: 0.7
    min_size: 16
  }
}
...
layer {
  name: 'rpn-data'
  type: 'Python'
  bottom: 'rpn_cls_score'
  bottom: 'gt_boxes'
  bottom: 'im_info'
  bottom: 'data'
  top: 'rpn_labels'
  top: 'rpn_bbox_targets'
  top: 'rpn_bbox_inside_weights'
  top: 'rpn_bbox_outside_weights'
  python_param {
    module: 'rpn.anchor_target_layer'
    layer: 'AnchorTargetLayer'
    param_str: "{'feat_stride': 16, 'scales': [3, 6, 9, 16, 32], 'ratios': [0.5, 0.667, 1.0, 1.5, 2.0]}"
  }
}
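  train.prototxt sets the box shapes for the proposal layer and for the RPN through a C++ layer and a Python layer respectively. The C++ implementation is in pvanet/caffe-fast-rcnn/src/caffe/layers/proposal.cpp, and the Python one is in pvanet/lib/rpn/generate_anchors.py.
  The anchor box base_size is 16. Following the computation in those files, each anchor has the shape:
  w = scale * base_size / sqrt(ratio)
  h = w * ratio
  So to improve detection, adjust scale and ratio until the anchor shapes match your objects.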
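
To see which box sizes a given set of scales and ratios produces, the two formulas can be tabulated directly. This sketch uses a made-up helper, `anchor_shapes`, and ignores any rounding the actual implementations apply, so the numbers are approximate.

```python
import math


def anchor_shapes(scales, ratios, base_size=16):
    """Return (ratio, scale, w, h) tuples using w = scale * base_size / sqrt(ratio), h = w * ratio."""
    shapes = []
    for ratio in ratios:
        for scale in scales:
            w = scale * base_size / math.sqrt(ratio)
            h = w * ratio
            shapes.append((ratio, scale, w, h))
    return shapes


# Values taken from the proposal_param block in train.prototxt.
scales = [3, 6, 9, 16, 32]
ratios = [0.5, 0.667, 1.0, 1.5, 2.0]
for ratio, scale, w, h in anchor_shapes(scales, ratios):
    print('ratio={:<5} scale={:<2} w={:7.1f} h={:7.1f}'.format(ratio, scale, w, h))
```

Note that for a fixed scale the area w * h stays at (scale * base_size)^2; the ratio only changes how that area is split between width and height.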

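3. pvanet/lib/fast_rcnn/config.py

This file matters a lot: almost every training-related setting lives in it.
[Tips]
1. TRAIN.HAS_RPN = True
2. Objects in the training samples must be larger than RPN_MIN_SIZE = 16. With a feature stride of 16, this corresponds to an object covering at least one pixel on the feature map that pvanet uses for RoI pooling.
3. Enable horizontal-flip augmentation of the samples.
For example: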
# Minibatch size (number of regions of interest [ROIs])
__C.TRAIN.BATCH_SIZE = 128
# Overlap required between a ROI and ground-truth box in order for that ROI to
# be used as a bounding-box regression training example
__C.TRAIN.BBOX_THRESH = 0.5   
# Iterations between snapshots
__C.TRAIN.SNAPSHOT_ITERS = 10000   
# Use RPN to detect objects
__C.TRAIN.HAS_RPN = True
#IOU >= thresh: positive example
__C.TRAIN.RPN_POSITIVE_OVERLAP = 0.7
# IOU < thresh: negative example
__C.TRAIN.RPN_NEGATIVE_OVERLAP = 0.3
......
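
Tip 2 can be checked on a dataset before training. The sketch below flags ground-truth boxes whose width or height falls below the minimum size once the image is rescaled. The helper `too_small`, the `im_scale` parameter, and the inclusive "+ 1" pixel convention are assumptions made for this example, not code from pvanet.

```python
import numpy as np


def too_small(gt_boxes, min_size=16, im_scale=1.0):
    """Boolean mask of boxes ([x1, y1, x2, y2]) narrower or shorter than min_size after scaling."""
    boxes = np.asarray(gt_boxes, dtype=np.float32) * im_scale
    widths = boxes[:, 2] - boxes[:, 0] + 1
    heights = boxes[:, 3] - boxes[:, 1] + 1
    return (widths < min_size) | (heights < min_size)


# Two hypothetical annotations: a narrow 11 x 31 box and a 100 x 100 box.
gt = [[10, 10, 20, 40], [0, 0, 99, 99]]
print(too_small(gt))
print(too_small(gt, im_scale=2.0))
```

Boxes flagged this way are candidates for removal or for training at a larger input scale.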

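4. Detection script

The test_net.py and demo.py scripts that ship with pvanet both work, but they are not very flexible.
I trained a few object classes of my own and, being lazy, wanted to display the 21 classes pvanet detects out of the box together with mine. Ignoring efficiency, the simplest approach is to run pvanet on each image twice, once per model.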
#!/usr/bin/env python

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Test a Fast R-CNN network on an image database."""

import _init_paths
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import caffe
import cv2
import argparse
import pprint
import time, os, sys

CLASSES2 = ('__background__',
          'aeroplane', 'bicycle', 'bird', 'boat',
          'bottle', 'bus', 'car', 'cat', 'chair',
          'cow', 'diningtable', 'dog', 'horse',
          'motorbike', 'person', 'pottedplant',
          'sheep', 'sofa', 'train', 'tvmonitor')
CLASSES = ('__background__','Pillar','NoEntering','GuideArrow'
           )

NETS = {'vgg16': ('VGG16',
                  'VGG16_faster_rcnn_final.caffemodel'),
        'zf': ('ZF',
                  'ZF_faster_rcnn_final.caffemodel')}


def demo(net, net2, image_name, _t):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im, _t)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3

    # cv2 loads images as BGR. Keep `im` untouched so the second network
    # receives the same input, and use an RGB copy only for display.
    im_rgb = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im_rgb, aspect='equal')

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]

        class_name = cls
        thresh = CONF_THRESH
        inds = np.where(dets[:, -1] >= thresh)[0]
        if len(inds) == 0:
            continue

        for i in inds:
            bbox = dets[i, :4]
            score = dets[i, -1]
            print('{} {}'.format(class_name, score))
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=3.5)
            )
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(class_name, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')

            ax.set_title(('{} detections with '
                      'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                      thresh),
                     fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()

    # Run the second network (net2) on the same image
    scores2, boxes2 = im_detect(net2, im, _t)
    for cls_ind2, cls2 in enumerate(CLASSES2[1:]):
        cls_ind2 += 1  # because we skipped background
        cls_boxes2 = boxes2[:, 4 * cls_ind2:4 * (cls_ind2 + 1)]
        cls_scores2 = scores2[:, cls_ind2]
        dets2 = np.hstack((cls_boxes2,
                          cls_scores2[:, np.newaxis])).astype(np.float32)
        keep2 = nms(dets2, NMS_THRESH)
        dets2 = dets2[keep2, :]
        class_name = cls2
        thresh = CONF_THRESH
        inds = np.where(dets2[:, -1] >= thresh)[0]
        if len(inds) == 0:
            continue

        for i in inds:
            # To hide unwanted classes, skip them here, e.g.:
            # if cls_ind2 in (1, 3, 19):  # aeroplane, bird, train
            #     continue
            bbox = dets2[i, :4]
            score = dets2[i, -1]
            print('{} {}'.format(class_name, score))

            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=3.5)
            )
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(class_name, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')

            ax.set_title(('{} detections with '
                          'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                          thresh),
                         fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()

def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU id to use',
                        default=0, type=int)
    parser.add_argument('--def', dest='prototxt',
                        help='prototxt file defining the network',
                        default=None, type=str)
    parser.add_argument('--net', dest='caffemodel',
                        help='model to test',
                        default=None, type=str)
    parser.add_argument('--def2', dest='prototxt2',
                        help='prototxt file defining the network',
                        default=None, type=str)
    parser.add_argument('--net2', dest='caffemodel2',
                        help='model to test',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file', default=None, type=str)
    parser.add_argument('--wait', dest='wait',
                        help='wait until net file exists',
                        default=True, type=bool)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to test',
                        default='voc_2007_test', type=str)
    parser.add_argument('--comp', dest='comp_mode', help='competition mode',
                        action='store_true')
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--vis', dest='vis', help='visualize detections',
                        action='store_true')
    parser.add_argument('--num_dets', dest='max_per_image',
                        help='max number of detections per image',
                        default=100, type=int)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    cfg.GPU_ID = args.gpu_id

    print('Using config:')
    pprint.pprint(cfg)

    while not os.path.exists(args.caffemodel) and args.wait:
        print('Waiting for {} to exist...'.format(args.caffemodel))
        time.sleep(10)

    # CPU mode. For GPU inference, call caffe.set_mode_gpu() followed by
    # caffe.set_device(args.gpu_id) instead.
    caffe.set_mode_cpu()

    net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST)
    net.name = os.path.splitext(os.path.basename(args.caffemodel))[0]

    net2 = caffe.Net(args.prototxt2, args.caffemodel2, caffe.TEST)               # load net from caffe 
    net2.name = os.path.splitext(os.path.basename(args.caffemodel2))[0]
    # Collect every image under data/demo/testimg/ (relative to the current directory)
    currentdir = os.getcwd()
    imgpath = currentdir + '/data/demo/testimg/'
    im_names = []
    filelist = os.listdir(imgpath)
    for files in filelist:
        name = 'testimg/'+files
        im_names.append(name)
    #im_names = [ 'testimg/pillar/111006.jpg','testimg/pillar/153549.jpg']
    _t = {'im_preproc': Timer(), 'im_net' : Timer(), 'im_postproc': Timer(), 'misc' : Timer()}
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(net, net2, im_name, _t)

    plt.show()
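
The per-class post-processing in the script (NMS followed by a confidence threshold) can be tried without Caffe. This is a stand-alone sketch in plain NumPy: `nms_numpy` and `filter_detections` are hypothetical names, and the greedy NMS here is a generic version, not the compiled `nms` wrapper the script imports.

```python
import numpy as np


def nms_numpy(dets, thresh):
    """Greedy NMS over rows of [x1, y1, x2, y2, score]; returns kept row indices."""
    x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Intersection of the best box with every remaining box.
        w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou <= thresh]  # drop boxes that overlap too much
    return keep


def filter_detections(dets, nms_thresh=0.3, conf_thresh=0.8):
    """Apply NMS, then keep only detections scoring at least conf_thresh."""
    dets = dets[nms_numpy(dets, nms_thresh), :]
    return dets[dets[:, -1] >= conf_thresh]


# Made-up detections: two overlapping boxes, one separate box, one low-score box.
dets = np.array([[10, 10, 50, 50, 0.95],
                 [12, 12, 52, 52, 0.90],
                 [100, 100, 150, 150, 0.85],
                 [200, 200, 220, 220, 0.30]], dtype=np.float32)
print(filter_detections(dets))
```

Running the same filter over the outputs of both networks and concatenating the results would give one combined list to draw, which is what the script does in two separate loops.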

The detection results look like this:
[Images: four detection-result screenshots]

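5. Drawing the network architecture
The script pvanet/caffe-fast-rcnn/python/draw_net.py draws the network architecture from a prototxt file:

python draw_net.py test.prototxt test.jpg --rankdir TB

The rendered graph is written to test.jpg.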
