keras faster r-cnn源代码解析(一)——训练过程

引言:

    开始看faster r-cnn的过程是这样的,想看自然场景文本检测,然后查到了CTPN,CTPN是基于Fast R-CNN的RPN进行的改进,然后就开始看Faster r-cnn,大牛写的论文根本看不懂,看了一遍论文只能朦朦胧胧有点印象这东西大概是搞什么的,遇到没见过的名词就查,刚开始有以下几个,

Q1:Regin proposal中的proposal:即比较可能是物体的一个区域

Q2:R-CNN:Region based CNN Regions with CNN features

Q3:Faster R-CNN:CNN with RPN(Reigon Proposal Networks)使用CNN计算Proposal用于提取特征的卷积网络同样能用于查找proposal

Q4:感受野:特征能代表的原图的像素区域大小,例如说感受野是228,即映射之后的一个特征,代表原图的228*288像素大小的区域。

    看了第二遍论文,比第一遍好一些,但仍然有大量疑惑,特别是具体是如何做的。怎么办?查。先后看了以下几个链接:

https://zhuanlan.zhihu.com/p/31426458一文读懂Faster RCNN

我自己的实践表明,如果一点都不懂,仅凭这一文根本读不懂。如果源代码读完了,确实读这一文就能读懂了。但不是因为读这一文读懂了。

http://www.telesens.co/2018/03/11/object-detection-and-classification-using-r-cnns/Object Detection and Classification using R-CNNs,外国人写的,应该是最详细的一篇,比上面那个一文读懂还要详细,但问题主要有以下几点:英文的,名词用的不太一样,代码是pytorch的,

https://www.jianshu.com/p/71fbb3251cbfFaster R-CNN原理详解(基于keras代码)

这个涉及到代码的讲解了,但是没有写完,而且代码和原理一起讲,有点乱

http://geyao1995.com/archives/page/3/

这个博客里有对keras版本的详细讲解,相对来说比较详细了,但缺点是老子自己吭哧吭哧看了3天刚把代码看完了才发现这个博客,早发现的话估计1天就看完了。

    代码看完了再回过头来看上面那些原理讲解就很清晰了,看原论文也知道作者在讲什么了,因此如果有质量比较好的源代码应该先源代码,然后再结合着看详解。

   解析代码:https://github.com/yiqisetian/Keras-FasterRCNN,这个代码来源于https://github.com/you359/Keras-FasterRCNN

里面有一些韩语注释,难道说是棒子写的?

  完整的能运行的代码目录如下:

keras faster r-cnn源代码解析(一)——训练过程_第1张图片

    其中config.pickle是运行脚本之后保存的配置文件,vgg16的权重文件需要自己下载,VOC2012也需要自己下载,test.py没有用,其他都是原始的代码。

    train_frcnn.py是用于训练的,keras_frcnn文件夹里面存放的是实现faster r-cnn所用到的各种类和方法。

   一、引入模块

from __future__ import division
import random
import pprint
import sys
import time
import numpy as np
from optparse import OptionParser#python命令解析模块,注意:从2.7版本后不再使用:optparse模块不推荐使用,python不再更新该模块,后续的发展将推荐使用argparse模块。
import pickle
import os

import tensorflow as tf
from keras import backend as K
from keras.optimizers import Adam, SGD, RMSprop
from keras.layers import Input
from keras.models import Model
from keras_frcnn import config, data_generators
from keras_frcnn import losses as losses
import keras_frcnn.roi_helpers as roi_helpers
from keras.utils import generic_utils
from keras.callbacks import TensorBoard

二、解析命令参数

# tensorboard 
def write_log(callback, names, logs, batch_no):
    for name, value in zip(names, logs):
        summary = tf.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        callback.writer.add_summary(summary, batch_no)
        callback.writer.flush()

sys.setrecursionlimit(40000)#设置python最大递归深度

parser = OptionParser()#声明一个OptionParser对象

parser.add_option("-p", "--path", dest="train_path", help="Path to training data.")#训练集路径
parser.add_option("-o", "--parser", dest="parser", help="Parser to use. One of simple or pascal_voc",
                  default="pascal_voc")#指定数据集,pascal_voc或者simple
parser.add_option("-n", "--num_rois", dest="num_rois", help="Number of RoIs to process at once.", default=300)
parser.add_option("--network", dest="network", help="Base network to use. Supports vgg or resnet50.", default='vgg')
parser.add_option("--hf", dest="horizontal_flips", help="Augment with horizontal flips in training. (Default=false).", action="store_true", default=False)
parser.add_option("--vf", dest="vertical_flips", help="Augment with vertical flips in training. (Default=false).", action="store_true", default=False)
parser.add_option("--rot", "--rot_90", dest="rot_90", help="Augment with 90 degree rotations in training. (Default=false).",
                  action="store_true", default=False)
parser.add_option("--num_epochs", dest="num_epochs", help="Number of epochs.", default=2000)
parser.add_option("--config_filename", dest="config_filename",
                  help="Location to store all the metadata related to the training (to be used when testing).",
                  default="config.pickle")
parser.add_option("--output_weight_path", dest="output_weight_path", help="Output path for weights.", default='./model_frcnn.hdf5')
parser.add_option("--input_weight_path", dest="input_weight_path", help="Input path for weights. If not specified, will try to load default weights provided by keras.")

(options, args) = parser.parse_args()#解析参数

if not options.train_path:   # 要给出训练路径
    parser.error('Error: path to training data must be specified. Pass --path to command line')

if options.parser == 'pascal_voc':#使用pascal_voc数据集
    from keras_frcnn.pascal_voc_parser import get_data
elif options.parser == 'simple':
    from keras_frcnn.simple_parser import get_data
else:
    raise ValueError("Command line option parser must be one of 'pascal_voc' or 'simple'")

# pass the settings from the command line, and persist them in the config object
C = config.Config()#使用keras_frcnn/config.py生成一个Config对象
#用于数据增强
C.use_horizontal_flips = bool(options.horizontal_flips)
C.use_vertical_flips = bool(options.vertical_flips)
C.rot_90 = bool(options.rot_90)

C.model_path = options.output_weight_path
C.num_rois = int(options.num_rois)#Number of RoIs to process at once.
#选择基础网络模型
if options.network == 'vgg':
    C.network = 'vgg'
    from keras_frcnn import vgg as nn
elif options.network == 'resnet50':
    from keras_frcnn import resnet as nn
    C.network = 'resnet50'
elif options.network == 'xception':
    from keras_frcnn import xception as nn
    C.network = 'xception'
elif options.network == 'inception_resnet_v2':
    from keras_frcnn import inception_resnet_v2 as nn
    C.network = 'inception_resnet_v2'
else:
    print('Not a valid model')
    raise ValueError
#选择基础网络模型权重的路径
# check if weight path was passed via command line
if options.input_weight_path:
    C.base_net_weights = options.input_weight_path
else:
    # set the path to weights based on backend and model
    C.base_net_weights = nn.get_weight_path()#加载网络权重

三、解析数据

  这段最重要的就是get_data方法,默认情况下是引用keras_frcnn/pascal_voc_parser.py中的get_data方法,输入和输出详解见注释。解析之后得到all_imgs, classes_count, class_mapping三个变量,all_imgs用于存储所有的图片信息(注意这里是图片的信息,不是图片的像素),classes_count用于存放类别的计数(是图片中标注框的计数,因此要远远多于图片的数量),class_mapping用于存放样本类别是数字的关系,即将样本类别字符串转换为数字以便于形成one_hot标签。

# parser
#输入:
#数据集所在路径,这个是数据集所在路径,在路径下要包含VOC2012文件夹
#输出:
#all_imgs的每一项都包含['filepath','width','height','imageid','imageset','bbox'{'class_name','x1','y1','x2','y2','difficult'}]imageset表示是训练集还是测试集
#示例如下:
'''
all_img_data[0] = {'width': 500, 'height': 500,
                 'bboxes': [{'y2': 500, 'y1': 27, 'x2': 183, 'x1': 20, 'class': 'person', 'difficult': False},
                            {'y2': 500, 'y1': 2, 'x2': 249, 'x1': 112, 'class': 'person', 'difficult': False},
                            {'y2': 490, 'y1': 233, 'x2': 376, 'x1': 246, 'class': 'person', 'difficult': False},
                            {'y2': 468, 'y1': 319, 'x2': 356, 'x1': 231, 'class': 'chair', 'difficult': False},
                            {'y2': 450, 'y1': 314, 'x2': 58, 'x1': 1, 'class': 'chair', 'difficult': True}], 'imageset': 'test',
                 'filepath': './datasets/VOC2007/JPEGImages/000910.jpg'}
'''
#classes_count存放每类的标注框的数量
#e.g.{'sheep': 8, 'horse': 5, 'bg': 0, 'bicycle': 7, 'motorbike': 15, 'cow': 6, 'car': 34, 'aeroplane': 2, 'dog': 4, 'bus': 4, 'cat': 6, 'person': 113, 'train': 7, 'diningtable': 4, 'bottle': 3, 'sofa': 9, 'pottedplant': 7, 'tvmonitor': 7, 'chair': 27, 'bird': 6, 'boat': 7}
#classes_mapping存放样本类别数字和字符串的对应关系,例如'bg':0
all_imgs, classes_count, class_mapping = get_data(options.train_path)

# bg 添加一个背景类
if 'bg' not in classes_count:
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)

C.class_mapping = class_mapping

inv_map = {v: k for k, v in class_mapping.items()}#将class_mapping转换为字典格式{0:’bg'}

四、保存配置文件,划分训练集、验证集和测试集

print('Training images per class:')
pprint.pprint(classes_count)
print('Num classes (including bg) = {}'.format(len(classes_count)))

config_output_filename = options.config_filename

with open(config_output_filename, 'wb') as config_f:
    pickle.dump(C, config_f)#保存文件
    print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(config_output_filename))

random.shuffle(all_imgs)

num_imgs = len(all_imgs)
#划分训练集
train_imgs = [s for s in all_imgs if s['imageset'] == 'train']
val_imgs = [s for s in all_imgs if s['imageset'] == 'val']
test_imgs = [s for s in all_imgs if s['imageset'] == 'test']

print('Num train samples {}'.format(len(train_imgs)))
print('Num val samples {}'.format(len(val_imgs)))
print('Num test samples {}'.format(len(test_imgs)))

五、声明数据的generator对象
 

#作用:导入数据,自己生成一个data_generators回调函数,使用yield返回会得到一个generator对象
#输入:
#all_img_data,待处理的文件列表,里面存的对象如下:
#all_imgs的每一项都包含['filepath','width','height','imageid','imageset','bbox'{'class_name','x1','y1','x2','y2','difficult'}]imageset表示是训练集还是测试集
'''
all_img_data[0] = {'width': 500, 'height': 500,
                 'bboxes': [{'y2': 500, 'y1': 27, 'x2': 183, 'x1': 20, 'class': 'person', 'difficult': False},
                            {'y2': 500, 'y1': 2, 'x2': 249, 'x1': 112, 'class': 'person', 'difficult': False},
                            {'y2': 490, 'y1': 233, 'x2': 376, 'x1': 246, 'class': 'person', 'difficult': False},
                            {'y2': 468, 'y1': 319, 'x2': 356, 'x1': 231, 'class': 'chair', 'difficult': False},
                            {'y2': 450, 'y1': 314, 'x2': 58, 'x1': 1, 'class': 'chair', 'difficult': True}], 'imageset': 'test',
                 'filepath': './datasets/VOC2007/JPEGImages/000910.jpg'}
'''
#class_count所有图片中标注框的类别计数
#C配置文件是否翻转、anchor比例和大小、输入图片的缩放目标大小等等
#img_length_calc_function,是原始图片和特征图大小的对应关系,例如vgg由于采用了4个pooling层且卷积层padding为1,不缩放图片,因此vgg得到的特征图的大小是原图的16分之一
#而resnet的网路结构更加复杂,从最开始的图片到特征图的对应关系就较为复杂,这个函数和具体的网路有关,这里以vgg为例
#backend指定神经网络后端,就是th和tf两种,
#mode只有train和非train两种,train会再次shuffle数据,并且可以使用数据增强,而val和test则不需要这两个操作。
#输出:yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug
#x_img:(增强后的)原始图像
#[np.copy(y_rpn_cls), np.copy(y_rpn_regr)]:
#y_rpn_cls:所有anchor的valid和overlap标记
#y_rpn_regr:所有anchor的overlap标记和每个GTbox对应的最优的回归参数
#输出的结果是包含所有anchor的结果,其中valid为1的是neg和pos类的,0是中性的,overlap为1的是有交叉的,0是无交叉的,
#因此有交叉的、valid为1 的就是pos的anchor,在对应的位置记录了对应最优的回归参数
#y_rpn_cls->(1*37*37*18)前9个为valid后9个为overlap,y_rpn_regr->(1*37*37*72),前36个为overlap标记(实际上只有9个,是复制了4份),后36个为bestreg
#img_data_aug,增强后的图像信息
#主要是为了获得[np.copy(y_rpn_cls), np.copy(y_rpn_regr)]
data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='train')
#nn.get_img_output_length获取网络输出大小,例如VGG有4个pooling层,因此特征图的宽、高是原图的16分之一
data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='val')
data_gen_test = data_generators.get_anchor_gt(test_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='val')

六、构建网络

if K.image_dim_ordering() == 'th':
    input_shape_img = (3, None, None)
else:
    input_shape_img = (None, None, 3)#(W,H,C)

# input placeholder 정의
img_input = Input(shape=input_shape_img)
roi_input = Input(shape=(None, 4))

# base network(feature extractor) 정의 (resnet, VGG, Inception, Inception Resnet V2, etc)
shared_layers = nn.nn_base(img_input, trainable=True)#以VGG为例,输出的feature map为(37*37*512) 总的输出是(none,37,37,512)

# define the RPN, built on the base layers
# RPN 정의
num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)#anchor的数量,9个
#输出:rpn=[x_class, x_regr, base_layers](37,37,9+36+512)
#x_class:(37,37,9),每个位置上的每个anchor的二分类的结果
#x_regr:(37,37,36=9*4),激活函数采用线性回归,得到每个anchor的回归参数
#base_layers:(37,37,512),vgg的feature map输出
rpn = nn.rpn(shared_layers, num_anchors)

# detection network 정의
'''
网络的输入:
base_layer: 也就是前面的Vgg网络的输出,同样其shape为(37 * 37 * 512 )
input_rois: 就是RPN网络提取的RoI.anchor在feature map上的对应的box
num_rois: 前面R-CNN和fast R-CNN通过Slective search提取的RoI的数量大约是2000个,但是由于RPN网络提取的RoI是有目的性的,仅仅提取其中不超过300个就好.在代码本keras版本的代码中,默认设置的时32个,这个参数可以根据实际情况调整.
nb_classes: 指的数据集中所有的类别数,有20个前景类别,另外 加一个背景,总共21类
#输出:[out_class, out_regr]
#out_class:anchor的分类
#out_regr:anchor的回归参数
#分类和回归用的是同一个网络,输入都是anchor
'''
classifier = nn.classifier(shared_layers, roi_input, C.num_rois, nb_classes=len(classes_count), trainable=True)

model_rpn = Model(img_input, rpn[:2])#input->img_input ,out_put->rpn[:2]->[x_class:(37,37,9),x_regr:(37,37,36=9*4)]
model_classifier = Model([img_input, roi_input], classifier)#input->[img_input, roi_input],多输入,输出classifier->[out_class, out_regr]

# this is a model that holds both the RPN and the classifier, used to load/save weights for the models
model_all = Model([img_input, roi_input], rpn[:2] + classifier)
try:
    # load_weights by name
    # some keras application model does not containing name
    # for this kinds of model, we need to re-construct model with naming
    print('loading weights from {}'.format(C.base_net_weights))
    model_rpn.load_weights(C.base_net_weights, by_name=True)
    model_classifier.load_weights(C.base_net_weights, by_name=True)
except:
    print('Could not load pretrained model weights. Weights can be found in the keras application folder \
        https://github.com/fchollet/keras/tree/master/keras/applications')

optimizer = Adam(lr=1e-5)
optimizer_classifier = Adam(lr=1e-5)
model_rpn.compile(optimizer=optimizer, loss=[losses.rpn_loss_cls(num_anchors), losses.rpn_loss_regr(num_anchors)])
model_classifier.compile(optimizer=optimizer_classifier, loss=[losses.class_loss_cls, losses.class_loss_regr(len(classes_count)-1)], metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
model_all.compile(optimizer='sgd', loss='mae')

网络结构如下图所示:

原始图像Image有GTbox和类别标签,记为(X,Y),然后从Image中提取候选Anchors以及Bounding Box Regression(后面讲解),使用基础网络(例如VGG)提取特征图(feature map),将特征图和候选Anchors作为(X',Y')输入到RPN网络中,再次生成了一个Anchor是分类和Bounding Box Regression,然后在feature map再选取一次Anchor,这次的Anchor就成为ROI,然后将这些ROI做一个Pooling,保证这些ROI都是一样大的(就相当于在使用VGG之前将图像都缩放成一样大小),然后训练分类器,最后完成RPN和分类器的训练。

七、设置日志参数

# Tensorboard log폴더 생성
log_path = './logs'
if not os.path.isdir(log_path):
    os.mkdir(log_path)

# Tensorboard log모델 연결
callback = TensorBoard(log_path)
callback.set_model(model_all)

epoch_length = 1000
num_epochs = int(options.num_epochs)
iter_num = 0
train_step = 0

losses = np.zeros((epoch_length, 5))
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []
start_time = time.time()

best_loss = np.Inf

class_mapping_inv = {v: k for k, v in class_mapping.items()}#{0:’bg'}
print('Starting training')

八、开始训练

for epoch_num in range(num_epochs):

    progbar = generic_utils.Progbar(epoch_length)   # keras progress bar 사용
    print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

    while True:
        # try:
        # mean overlapping bboxes 출력
        if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
            mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
            rpn_accuracy_rpn_monitor = []
            print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(mean_overlapping_bboxes, epoch_length))
            if mean_overlapping_bboxes == 0:
                print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')

        #输出:yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug
        #网络训练的输入是图片,和找到的最优的anchor,不是GTbox
        X, Y, img_data = next(data_gen_train)
        #
        loss_rpn = model_rpn.train_on_batch(X, Y)#Scalar training loss (if the model has a single output and no metrics) or list of scalars (if the model has multiple outputs and/or metrics). 
        write_log(callback, ['rpn_cls_loss', 'rpn_reg_loss'], loss_rpn, train_step)
        P_rpn = model_rpn.predict_on_batch(X)#获取预测的Y值,即[np.copy(y_rpn_cls), np.copy(y_rpn_regr)]
        #筛选出ROI,ROI指的是在feature map中合法box中具有最大概率包含物体的box,删除重叠率较高的box之后剩下来的box(xxx,4)
        R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
        # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
        #img_data,图片的信息,R候选ROI
        #输出:
        #X2:#选取的iou大于0.7的roi,这里用X2是为了和上面的X相区别,
        #Y1:对应的类别序号(1,xxx,21),类别标签是one_hot
        #Y2:[np.array(y_class_regr_label),np.array(y_class_regr_coords)]包含对应的类别的标签和回归参数,类别标签是one_hot的
        #IoUs:用于调试的,没有用
        X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)

        if X2 is None:
            rpn_accuracy_rpn_monitor.append(0)
            rpn_accuracy_for_epoch.append(0)
            continue

        # sampling positive/negative samples
        neg_samples = np.where(Y1[0, :, -1] == 1)#背景,最后一项代表背景分类
        pos_samples = np.where(Y1[0, :, -1] == 0)#非背景

        if len(neg_samples) > 0:
            neg_samples = neg_samples[0]
        else:
            neg_samples = []

        if len(pos_samples) > 0:
            pos_samples = pos_samples[0]
        else:
            pos_samples = []

        rpn_accuracy_rpn_monitor.append(len(pos_samples))
        rpn_accuracy_for_epoch.append((len(pos_samples)))

        if C.num_rois > 1:
            if len(pos_samples) < C.num_rois//2:#选取一些正例样本和一些反例样本,共300个,每类约150个
                selected_pos_samples = pos_samples.tolist()
            else:
                selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
            try:
                selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
            except:
                selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()

            sel_samples = selected_pos_samples + selected_neg_samples
        else:
            # in the extreme case where num_rois = 1, we pick a random pos or neg sample
            selected_pos_samples = pos_samples.tolist()
            selected_neg_samples = neg_samples.tolist()
            if np.random.randint(0, 2):
                sel_samples = random.choice(neg_samples)
            else:
                sel_samples = random.choice(pos_samples)
        #X用于生成share_layers,X2是挑选出来的roi
        loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])
        write_log(callback, ['detection_cls_loss', 'detection_reg_loss', 'detection_acc'], loss_class, train_step)
        train_step += 1

        losses[iter_num, 0] = loss_rpn[1]
        losses[iter_num, 1] = loss_rpn[2]

        losses[iter_num, 2] = loss_class[1]
        losses[iter_num, 3] = loss_class[2]
        losses[iter_num, 4] = loss_class[3]

        iter_num += 1

        progbar.update(iter_num, [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
                                  ('detector_cls', np.mean(losses[:iter_num, 2])), ('detector_regr', np.mean(losses[:iter_num, 3]))])

        if iter_num == epoch_length:
            loss_rpn_cls = np.mean(losses[:, 0])
            loss_rpn_regr = np.mean(losses[:, 1])
            loss_class_cls = np.mean(losses[:, 2])
            loss_class_regr = np.mean(losses[:, 3])
            class_acc = np.mean(losses[:, 4])

            mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
            rpn_accuracy_for_epoch = []

            if C.verbose:
                print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
                print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                print('Loss RPN regression: {}'.format(loss_rpn_regr))
                print('Loss Detector classifier: {}'.format(loss_class_cls))
                print('Loss Detector regression: {}'.format(loss_class_regr))
                print('Elapsed time: {}'.format(time.time() - start_time))

            curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
            iter_num = 0
            start_time = time.time()

            write_log(callback,
                      ['Elapsed_time', 'mean_overlapping_bboxes', 'mean_rpn_cls_loss', 'mean_rpn_reg_loss',
                       'mean_detection_cls_loss', 'mean_detection_reg_loss', 'mean_detection_acc', 'total_loss'],
                      [time.time() - start_time, mean_overlapping_bboxes, loss_rpn_cls, loss_rpn_regr,
                       loss_class_cls, loss_class_regr, class_acc, curr_loss],
                      epoch_num)

            if curr_loss < best_loss:
                if C.verbose:
                    print('Total loss decreased from {} to {}, saving weights'.format(best_loss,curr_loss))
                best_loss = curr_loss
                model_all.save_weights(C.model_path)

            break

        # except Exception as e:
        #     print('Exception: {}'.format(e))
        #     continue

前面主要部分已经添加了注释,后面的都是和数据显示有关,和核心内容无关

主要代码是以下几行:

loss_rpn = model_rpn.train_on_batch(X, Y)#训练RPN

P_rpn = model_rpn.predict_on_batch(X)#获取预测的Y值,即[np.copy(y_rpn_cls), np.copy(y_rpn_regr)]

相当于上图的Y’

#筛选出ROI,ROI指的是在feature map中合法box中具有最大概率包含物体的box,删除重叠率较高的box之后剩下来的box(xxx,4)
        R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)

在feature map上计算ROI(和在原图上计算候选Anchor或叫Proposal类似)

 X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)

训练分类器

loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

你可能感兴趣的:(人工智能)