one stage 的RFBNet在保证速度的前提下,也有着不错的精度,所以拿来训练kaggle上的RSNA。这边主要介绍下对RFBnet源码修改支持RSNA的训练,如果想看关于RSNA数据分析的,可以去看kaggle上的kernels。
RSNA跟常见的检测数据集(COCO,VOC,BDD100K,CITYSCAPE等)不一样的一个地方就是,图片中可能不存在标注,也就是说不存在foreground,我就隐隐觉得源码可能不支持这种情况,果然写完dataloader之后报错了,然后就需要修改源码了。
大家都有自己的风格,主要就是:
1.用SimpleITK读dicom
2.当前图像没有标注时,load annotation返回 np.zeros((1, 5))
源码会根据foreground的数量,按一定比例取一些background,但是如果没有foreground,background也没有,算正负样本分类的交叉熵就会报错。
我添加了一段逻辑,如果没有foreground,就选择4个background进行计算,对应下面代码55-58。
def forward(self, predictions, priors, targets):
"""Multibox Loss
Args:
predictions (tuple): A tuple containing loc preds, conf preds,
and prior boxes from SSD net.
conf shape: torch.size(batch_size,num_priors,num_classes)
loc shape: torch.size(batch_size,num_priors,4)
priors shape: torch.size(num_priors,4)
ground_truth (tensor): Ground truth boxes and labels for a batch,
shape: [batch_size,num_objs,5] (last idx is the label).
"""
loc_data, conf_data = predictions
priors = priors
num = loc_data.size(0)
num_priors = (priors.size(0))
num_classes = self.num_classes
# match priors (default boxes) and ground truth boxes
loc_t = torch.Tensor(num, num_priors, 4)
conf_t = torch.LongTensor(num, num_priors)
for idx in range(num):
truths = targets[idx][:, :-1].data
labels = targets[idx][:, -1].data
defaults = priors.data
match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)
if GPU:
loc_t = loc_t.cuda()
conf_t = conf_t.cuda()
# wrap targets
loc_t = Variable(loc_t, requires_grad=False)
conf_t = Variable(conf_t, requires_grad=False)
pos = conf_t > 0
# Localization Loss (Smooth L1)
# Shape: [batch,num_priors,4]
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
loc_p = loc_data[pos_idx].view(-1, 4)
loc_t = loc_t[pos_idx].view(-1, 4)
loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
# Compute max conf across batch for hard negative mining
batch_conf = conf_data.view(-1, self.num_classes)
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
# Hard Negative Mining
loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
loss_c = loss_c.view(num, -1)
_, loss_idx = loss_c.sort(1, descending=True)
_, idx_rank = loss_idx.sort(1)
num_pos = pos.long().sum(1, keepdim=True)
constant_min = torch.ones(num_pos.shape, dtype=torch.int64) * 4
neg_min = torch.max(self.negpos_ratio * num_pos, constant_min.cuda())
num_neg = torch.clamp(neg_min, max=pos.size(1) - 1)
neg = idx_rank < num_neg.expand_as(idx_rank)
# Confidence Loss Including Positive and Negative Examples
pos_idx = pos.unsqueeze(2).expand_as(conf_data)
neg_idx = neg.unsqueeze(2).expand_as(conf_data)
conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
targets_weighted = conf_t[(pos + neg).gt(0)]
loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)
# Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
N = max(num_pos.data.sum().float(), 1)
loss_l /= N
loss_c /= N
return loss_l, loss_c