Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇

Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇

继demo篇之后,继续分析faster rcnn的训练过程,继续对Pytorch版的源码进行分析。

1、参考文章/博客/论文/源码

论文:https://arxiv.org/pdf/1506.01497.pdf
源码:https://github.com/jwyang/faster-rcnn.pytorch

2、环境配置

见demo篇faster rcnn检测阶段源码分析

3、数据集及模型

以VGG16作为Backbone网络模型,数据集采用PASCAL VOC2007。

5、faster rcnn方法训练流程分析(训练阶段)

脚本文件:trainval_net.py
1、def parse_args():
参数传递函数,定义网络训练所需的相关参数,方便使用命令传递参数。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第1张图片
2、数据迭代器sampler类的定义及实现
该类继承Sampler类,在torch.utils.data.DataLoader构建数据集迭代器时使用。
sampler类重写了__iter__与__len__函数,将训练数据随机打乱,根据batch size的大小返回索引迭代器。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第2张图片
3、def combined_roidb(): 训练数据组织方法。

文件:lib/roi_data_layer/roidb.py
在这里插入图片描述
四个返回值分别表示的含义
**imdb:**表示根据voc_2007_trainval定义的pascal voc类对象,该类继承python的imdb类。
**roidb:**表示图像相关的信息,dict类型,包含boxes、gt_classes、gt_overlaps、image_id、image_path、width、height等图像本身和目标检测框标注的相关信息。
**ratio_list:**表示根据图像的宽高比排序后的ratio_list。
**ratio_index:**表示根据图像宽高比排序后的list对应的原始图像索引image_index。
在combined_roidb()方法中,通过调用get_roidb()方法获取图像相应的roidb信息。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第3张图片
文件:lib/datasets/imdb.py
调用append_flipped_images()方法对图像roidb中的box进行水平翻转,注意此时仅对box进行翻转,其中gt_overlaps和gt_classes与原始roidb相同,将dlipped翻转标记设置为True表示翻转后的图像。image_index * 2表示将图片索引信息复制一遍。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第4张图片
文件:lib/roi_data_layer/roidb.py
prepare_roidb()方法对每张图片的roidb进行信息扩充,添加id,路径,宽高,box类别信息等,将所有图像的尺寸序列化存储为pkl文件,方便再次运行时读取。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第5张图片
文件:lib/roi_data_layer/roidb.py
filter_roidb()方法对每张图像的box数量进行检查,在训练阶段剔除不含box目标检测框的图像及其roidb。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第6张图片
文件:lib/roi_data_layer/roidb.py
rank_roidb_ratio()方法检查图像的宽高比,将宽高比大于2或者小于0.5的图像的裁剪标志设置为True,并将宽高比更新为最大值或最小值,对训练图像进行裁剪时需要。将ratio从小到大排序,返回排序索引及排序后的ratio_list。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第7张图片
4、构建DataLoader可读的数据集
文件:lib/roi_data_layer/roibatchLoader.py
文件:lib/roi_data_layer/minibatch.py
定义roibatchLoader子类,该类继承自Pytorch的data.Dataset类,重写了__getitem__(self, index)函数和__len__(self)函数。

在该类的构造函数中,对排序后的图像宽高比根据batch_size的大小进行分段,保证每个batchsize中的图像具有相同大小的宽高比,对于faster rcnn训练PASCAL VOC数据集来说,一个batchsize只有一幅图像。

在__getitem__函数中,根据索引取出对应的roidb,并根据取出的roidb调用get_minibatch()方法构建blobs,此处的blobs为dict类型,包含data、gt_boxes、im_info、img_id四个属性。
data:表示图像本身像素信息,[1,w,h,c]结构,RGB->BGR,flipped为True的图像进行水平翻转。
gt_boxes:n×5结构的矩阵,n表示一幅图像中box的个数,前4列表示表示box的左上右下角坐标,最后一个列表示box的类别索引。(x1, y1, x2, y2, cls)
im_info:数组类型,1×3数组,分别存储了图像的宽度、高度、尺度缩放比例。
img_id:等同于roidb中的img_id属性,表示图像的索引id。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第8张图片
文件:lib/roi_data_layer/roibatchLoader.py
根据返回的blobs信息对需要进行裁剪的图像进行裁剪,如果图像的ratio大于2或者小于0.5则需要裁剪,裁剪后box的横坐标更新可能变为负值,或超出裁剪后的宽度范围,将超出边界的box做加紧处理,box的横坐标将变为边界值。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第9张图片
根据图像的ratio值对宽度或者高度向上取整,对图像进行相应的padding。
检查boundig box。
将data维度进行转换,将通道数维度提前,(3, data_height, data_width)
返回值包括padding_data, im_info, gt_boxes_padding, num_boxes
padding_data:表示边缘取整填充后的图像像素值。
im_info:表示裁剪填充后图像的宽高及缩放尺度。
gt_boxes_padding:表示bounding box信息。
num_boxes:表示bounding box的数量。
Pytorch版Faster R-CNN 源码分析+方法流程详解——训练篇_第10张图片

你可能感兴趣的:(python,深度学习,计算机视觉)