torchvision中已经为我们实现好了faster R-CNN模型,我们只需要调用即可。本文将对该模型做进一步的分析,使自己能够在使用过程中更加得心应手。
faster RCNN类继承了GeneralizedRCNN
类,该类是对faster RCNN模型的进一步抽象,我们首先来分析一下GeneralizedRCNN类。
class FasterRCNN(GeneralizedRCNN):
从构造函数中可以看出,GeneralizedRCNN类将faster RCNN抽象成了3部分:backbone
、rpn
、roi_heads
,外加一个对输入数据进行处理的transform
。这三部分的功能分别为:
backbone
:提取图片特征,输出feature map
rpn
:进行region proposal
roi_heads
:对roi进行分类和回归
这三部分包括transform的具体实现和细节将会在后面进行详细分析
class GeneralizedRCNN(nn.Module):
def __init__(self, backbone, rpn, roi_heads, transform):
super(GeneralizedRCNN, self).__init__()
self.transform = transform
self.backbone = backbone
self.rpn = rpn
self.roi_heads = roi_heads
# used only on torchscript mode
self._has_warned = False
下面来看GeneralizedRCNN中的具体流程:
在调用GeneralizedRCNN类的对象时,需要同时传入images
和targets
,这是因为GeneralizedRCNN类在训练时会利用targets计算loss。在我们自己定义模型时,通常只会输入images,模型会输出结果,然后由我们自己定义criterion
来计算损失,而torchvision中的faster RCNN在训练时输出的是loss。
def forward(self, images, targets=None):
首先模型会验证targets是否为None,当开启训练模式时,targets需要用来计算loss,所以不能为None?
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
如何开启训练模式?
model.train()
如果检测到为训练模式,并且targets
不为None,模型将会对targets
的数据格式进行检查。
模型要求targets
的格式为List[Dict[str, Tensor]]
(每一个Dict对象代表一张图片,Nx代表图片中ground truth的数量,'labels'
必须为一维的tensor):
[
{
'boxes':tensor(N1,4),
'labels':tensor(N1,)
},
{
'boxes':tensor(N2,4),
'labels':tensor(N2,)
},
。。。。。。
]
模型将会验证'boxes'
是否为tensor
,并且形状是否为(N,4)
if self.training:
assert targets is not None
# targets=[{'boxes': tensor(N,4), 'labels': tensor(N,)},{'boxes': tensor(), 'labels': tensor()},……]
# 'boxes'需要为tensor
for target in targets:
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
# 'boxes'的shape需要为(N,4)
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor"
"of shape [N, 4], got {:}.".format(
boxes.shape))
else:
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))
下面模型会记录图片的原始尺寸,因为后面在预处理中会对模型进行resize,而最后输出时还要进行后处理,恢复到原始尺寸。
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
# images=[tensor(3,H,W),tensor(3,H,W),……](batch_size个图片)
for img in images:
val = img.shape[-2:]
assert len(val) == 2
# 保存图片的(H,W)
original_image_sizes.append((val[0], val[1]))
# original_image_sizes=[(H, W), (H, W), (H, W), (H, W)]
这里需要注意的是images
的格式:List[Tensor]
。
图片的大小可以不一样,但是必须要3通道。
[
tensor(3,H1,W1),
tensor(3,H2,W2),
。。。。。。
]
下面模型将对images
和targets
进行预处理,预处理的具体细节我们将在后面做详细讨论。
images, targets = self.transform(images, targets)
上面模型已经对targets
的格式进行了验证,下面模型将会对targets
的内容进行验证。
’boxes‘
中每一个box的具体内容为[xmin,ymin,xmax,ymax]
,模型会验证xmin
、ymin
是不是小于等于xmax
、ymax
if targets is not None:
for target_idx, target in enumerate(targets): # {'boxes': tensor(N,4), 'labels': tensor(N,)}
boxes = target["boxes"] # [xmin,ymin,xmax,ymax]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] # xmin,ymin<=xmax,ymax
if degenerate_boxes.any():
# print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}."
.format(degen_bb, target_idx))
上面的内容仅仅是对数据进行验证和预处理,如果数据通过验证,下面将进入核心阶段
1、使用backbone
提取images
特征
features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
2、进行rpn
进行region proposal
proposals, proposal_losses = self.rpn(images, features, targets)
3、使用roi_heads
对roi
进行分类和回归
detections, detector_losses = self.roi_heads(
features, proposals, images.image_sizes, targets)
核心过程之后,已经得到检测结果,下面将对检测结果进行后处理
detections = self.transform.postprocess(
detections, images.image_sizes, original_image_sizes)
核心过程和后处理的细节也会在后面进行分析
最后,模型将统计roi_heads
和rpn
阶段产生的loss,并将loss和检测结果输出
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
return self.eager_outputs(losses, detections)
这里再return语句中调用了eager_outputs()
函数,我们来看一下该函数的具体内容:
def eager_outputs(self, losses, detections):
if self.training:
return losses
return detections
也就是说,在训练模式下会返回loss,而测试模式下会返回检测结果。
上面是GeneralizedRCNN类中对Faster RCNN整个过程的抽象,FasterRCNN类在继承GeneralizedRCNN类时并没有重写forward函数。