使用RFBNet训练kaggle RSNA数据检测胸片的肺炎

one stage 的RFBNet在保证速度的前提下,也有着不错的精度,所以拿来训练kaggle上的RSNA。这边主要介绍下对RFBnet源码修改支持RSNA的训练,如果想看关于RSNA数据分析的,可以去看kaggle上的kernels。

数据集介绍

RSNA跟常见的检测数据集(COCO,VOC,BDD100K,CITYSCAPE等)不一样的一个地方就是,图片中可能不存在标注,也就是说不存在foreground,我就隐隐觉得源码可能不支持这种情况,果然写完dataloader之后报错了,然后就需要修改源码了。

源码修改

1.自己写个支持RSNA的dataloader

大家都有自己的风格,主要就是:

1.用SimpleITK读dicom

2.当前图像没有标注时,load annotation返回 np.zeros((1, 5))

2.修改multibox_loss.py

源码会根据foreground的数量,按一定比例取一些background,但是如果没有foreground,background也没有,算正负样本分类的交叉熵就会报错。

我添加了一段逻辑,如果没有foreground,就选择4个background进行计算,对应下面代码55-58。

    def forward(self, predictions, priors, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)

            ground_truth (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size,num_objs,5] (last idx is the label).
        """

        loc_data, conf_data = predictions
        priors = priors
        num = loc_data.size(0)
        num_priors = (priors.size(0))
        num_classes = self.num_classes

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
        # wrap targets
        loc_t = Variable(loc_t, requires_grad=False)
        conf_t = Variable(conf_t, requires_grad=False)

        pos = conf_t > 0

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard Negative Mining
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)

        constant_min = torch.ones(num_pos.shape, dtype=torch.int64) * 4
        neg_min = torch.max(self.negpos_ratio * num_pos, constant_min.cuda())
        num_neg = torch.clamp(neg_min, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos + neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N

        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        return loss_l, loss_c

 

你可能感兴趣的:(深度学习探索,图像处理)