YOLO v3 PyTorch版本源码解读(二):数据集的读取以及网络的训练

PyTorch 代码链接:https://github.com/ultralytics/yolov3

本篇主要是对代码文件中 原始数据读取方式 的解读,集中分析了 utils 文件夹下的 datasets.py函数。并分析了损失函数的组成、mAP的计算。

1. datasets.py

1.1 letterbox函数

我们知道 yolo v3是支持任意尺寸的图片输入的,但我们定义的网络结构是输入尺寸是固定的,所以在读取 原始图片时肯定是做了一些处理的,将原始图片的尺寸统一转换为我们定义的网络输入尺寸如416等。

比如网络输入尺寸是 416×416,原图会按照纵横比等比例缩放:取 min(w/img_w, h/img_h) 这个比例来缩放,保证长边缩放为需要的输入尺寸 416,而短边按同一比例缩放,图像不会被扭曲。例如 img_w, img_h 是原图尺寸 768, 576,缩放后的尺寸为 new_w, new_h = 416, 312,需要的输入尺寸是 w, h = 416, 416。如下图所示:
YOLO v3 PyTorch版本源码解读(二):数据集的读取以及网络的训练_第1张图片
剩下的灰色区域用 (128,128,128) 填充即可构造为 416×416。这部分的代码对应 datasets.py 中的 letterbox 函数。

def letterbox(img, new_shape=416, color=(128, 128, 128), mode='auto'):
    """Resize an image with a single scale factor and pad it with a constant colour.

    See https://github.com/ultralytics/yolov3/issues/232 for background.

    Args:
        img: Image array; only ``img.shape[:2]`` (height, width) is inspected
            here before the array is handed to OpenCV.
        new_shape: Target size. An ``int`` is used as a square side length.
            Any other value is treated as an indexable ``(height, width)``
            pair; that form is only usable with ``mode='rect'`` because the
            other modes do arithmetic on ``new_shape`` directly.
        color: Border colour passed to ``cv2.copyMakeBorder``.
        mode: One of
            ``'auto'``      - pad each side only up to the next multiple of 32,
            ``'square'``    - pad to ``new_shape`` x ``new_shape``,
            ``'rect'``      - pad to ``new_shape[0]`` x ``new_shape[1]``,
            ``'scaleFill'`` - stretch to ``new_shape`` x ``new_shape``, no padding.

    Returns:
        ``(img, ratiow, ratioh, dw, dh)``: the padded image, the width and
        height scale factors, and the (possibly fractional) padding applied
        per side along width and height.

    Raises:
        ValueError: If ``mode`` is not one of the four supported values.
    """
    shape = img.shape[:2]  # current shape [height, width]

    # Scale by the longer side; whatever is left is filled with padding.
    if isinstance(new_shape, int):
        ratio = float(new_shape) / max(shape)
    else:
        ratio = max(new_shape) / max(shape)  # ratio = new / old
    ratiow, ratioh = ratio, ratio

    # (width, height) after scaling and before padding -- cv2.resize order.
    new_unpad = (int(round(shape[1] * ratio)), int(round(shape[0] * ratio)))

    # Compute padding https://github.com/ultralytics/yolov3/issues/232
    # Strings are compared with ==; identity ("is") only works by accident of
    # interning and raises a SyntaxWarning on Python 3.8+.
    if mode == 'auto':  # minimum rectangle
        dw = np.mod(new_shape - new_unpad[0], 32) / 2  # width padding
        dh = np.mod(new_shape - new_unpad[1], 32) / 2  # height padding
    elif mode == 'square':  # square
        dw = (new_shape - new_unpad[0]) / 2  # width padding
        dh = (new_shape - new_unpad[1]) / 2  # height padding
    elif mode == 'rect':  # explicit (height, width) rectangle
        dw = (new_shape[1] - new_unpad[0]) / 2  # width padding
        dh = (new_shape[0] - new_unpad[1]) / 2  # height padding
    elif mode == 'scaleFill':  # stretch, aspect ratio not preserved
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape, new_shape)
        ratiow, ratioh = new_shape / shape[1], new_shape / shape[0]
    else:
        raise ValueError("Unsupported letterbox mode: %r" % (mode,))

    if shape[::-1] != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_AREA)  # resize

    # The -0.1 / +0.1 offsets split a half-pixel pad (x.5) into floor on one
    # side and ceil on the other, so both sides still sum to the full padding.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratiow, ratioh, dw, dh

1.2. LoadImagesAndLabels类

这个类很长,我们重点看 __getitem__ 中对原始图片的读取:首先调用上面的 letterbox 函数,然后根据 letterbox 返回的缩放系数以及填充量,把标签中归一化的 xywh 换算成新图片上的像素坐标 x1y1x2y2;接着利用 xyxy2xywh 函数将“左上角、右下角表示 bounding box”的方法转换为“中心点和宽高表示”;最后将 x、w 除以新图片的宽,y、h 除以新图片的高,做归一化处理。这是为了之后训练时 loss 数值维度上的统一。可以看到的是之后在训练过程中使用的真实标签中的 w、h 就是归一化之后的 w 和 h,而 x、y 则是通过 gxy = (x, y) * ng, gxy -= gxy.floor() 来计算得到,这个在 utils.py 中的 build_targets 函数中可以看到。

class LoadImagesAndLabels(Dataset):  # for training/testing
    """Dataset yielding letterboxed images together with their box labels.

    ``path`` is a text file with one image path per line; entries whose
    extension is not in the module-level ``img_formats`` are dropped. The
    label file of each image is derived from the image path by replacing
    ``images`` with ``labels`` and the extension with ``.txt``.

    Each label row holds five floats. ``__getitem__`` reads columns 1-4 as a
    normalized ``x_center, y_center, width, height`` box; column 0 is passed
    through untouched (presumably the class id -- confirm against the label
    format used by the project).
    """

    def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False):
        """Index the image list and, for rectangular training, the batch shapes.

        Args:
            path: Text file listing the image paths.
            img_size: Target size of the longer image side.
            batch_size: Used to assign every image to a batch index.
            augment: Enable HSV / affine / flip augmentation in ``__getitem__``.
            hyp: Mapping with the augmentation hyper-parameters read in
                ``__getitem__`` (``hsv_s``, ``hsv_v``, ``degrees``,
                ``translate``, ``scale``, ``shear``).
            rect: Sort images by aspect ratio and give every batch its own
                rectangular shape. Forced to ``False`` when ``image_weights``
                is set.
            image_weights: When true, ``__getitem__`` remaps its index through
                ``self.indices``, which is NOT created here (presumably set by
                the training loop -- verify against the caller).

        Raises:
            AssertionError: If no usable image path is found in ``path``.
        """
        with open(path, 'r') as f:
            self.img_files = [x for x in f.read().splitlines() if os.path.splitext(x)[-1].lower() in img_formats]

        n = len(self.img_files)
        # Check before indexing bi[-1]; an empty list would otherwise fail
        # with an unrelated IndexError instead of this message.
        assert n > 0, 'No images found in %s' % path
        # np.int was removed in NumPy 1.24; it was only an alias of int.
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        self.n = n
        self.batch = bi  # batch index of image
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect

        # Define labels: swap 'images' -> 'labels' and the extension -> '.txt'
        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
                            for x in self.img_files]

        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
        if self.rect:
            # Read image shapes, cached in data/<list name>.shapes
            sp = 'data' + os.sep + path.replace('.txt', '.shapes').split(os.sep)[-1]  # shapefile path
            try:
                with open(sp, 'r') as f:  # read existing shapefile
                    s = [x.split() for x in f.read().splitlines()]
                    assert len(s) == n, 'Shapefile out of sync'
            except Exception:
                # Missing, unreadable or stale cache: rebuild it from the images.
                s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
                np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)

            # Sort by aspect ratio (column 1 / column 0 of the shape table;
            # presumably height / width -- depends on exif_size).
            s = np.array(s, dtype=np.float64)
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            order = ar.argsort()
            self.img_files = [self.img_files[k] for k in order]
            self.label_files = [self.label_files[k] for k in order]
            ar = ar[order]

            # Set training image shapes: one relative [height, width] per batch
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            # Round every side up to a multiple of 32.
            self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(int) * 32

        # Preload labels (required for weighted CE training)
        self.imgs = [None] * n
        self.labels = [None] * n
        preload_labels = False
        if preload_labels:
            self.labels = [np.zeros((0, 5))] * n
            label_iter = tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files
            extract_bounding_boxes = False
            for i, file in enumerate(label_iter):
                try:
                    with open(file, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
                        if l.shape[0]:
                            assert l.shape[1] == 5, '> 5 label columns: %s' % file
                            assert (l >= 0).all(), 'negative labels: %s' % file
                            assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                            self.labels[i] = l

                            # Extract object detection boxes for a second stage classifier
                            if extract_bounding_boxes:
                                p = Path(self.img_files[i])
                                img = cv2.imread(str(p))
                                h, w, _ = img.shape
                                for j, x in enumerate(l):
                                    crop_path = '%s%sclassification%s%g_%g_%s' % (
                                        p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                                    if not os.path.exists(Path(crop_path).parent):
                                        os.makedirs(Path(crop_path).parent)  # make new output folder
                                    box = xywh2xyxy(x[1:].reshape(-1, 4)).ravel()
                                    box = np.clip(box, 0, 1)  # clip boxes outside of image
                                    result = cv2.imwrite(crop_path, img[int(box[1] * h):int(box[3] * h),
                                                                        int(box[0] * w):int(box[2] * w)])
                                    if not result:
                                        print('stop')
                except Exception:
                    # Deliberately best-effort: an image without a (valid)
                    # label file keeps its empty (0, 5) label array.
                    pass  # print('Warning: missing labels for %s' % self.img_files[i])  # missing label file
            assert len(np.concatenate(self.labels, 0)) > 0, 'No labels found. Incorrect label paths provided.'

        # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
        detect_corrupted_images = False
        if detect_corrupted_images:
            from skimage import io  # conda install -c conda-forge scikit-image
            for file in tqdm(self.img_files, desc='Detecting corrupted images'):
                try:
                    _ = io.imread(file)
                except Exception:
                    print('Corrupted image detected: %s' % file)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_files)

    def __getitem__(self, index):
        """Load, augment and letterbox one image and convert its labels.

        Returns:
            ``(img, labels_out, img_path, (h, w))`` where ``img`` is a float32
            CHW RGB tensor scaled to 0-1, ``labels_out`` has shape ``(nL, 6)``
            with column 0 left at zero (filled by ``collate_fn``) and columns
            1-5 holding the label row with its box as normalized xywh, and
            ``(h, w)`` is the image size right before letterboxing.
        """
        if self.image_weights:
            index = self.indices[index]

        img_path = self.img_files[index]
        label_path = self.label_files[index]
        hyp = self.hyp

        # Load image
        img = self.imgs[index]
        if img is None:
            img = cv2.imread(img_path)  # BGR
            assert img is not None, 'File Not Found ' + img_path
            r = self.img_size / max(img.shape)  # size ratio
            if self.augment and r < 1:  # if training (NOT testing), downsize to inference shape
                h, w, _ = img.shape
                img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA)

            if self.n < 3000:  # cache into memory if image count < 3000
                self.imgs[index] = img

        # Augment colorspace
        augment_hsv = True
        if self.augment and augment_hsv:
            # SV augmentation by 50%
            img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)  # hue, sat, val
            S = img_hsv[:, :, 1].astype(np.float32)  # saturation
            V = img_hsv[:, :, 2].astype(np.float32)  # value

            a = random.uniform(-1, 1) * hyp['hsv_s'] + 1
            b = random.uniform(-1, 1) * hyp['hsv_v'] + 1
            S *= a
            V *= b

            # Gains below 1 cannot exceed 255, so clipping is only needed otherwise.
            img_hsv[:, :, 1] = S if a < 1 else S.clip(None, 255)
            img_hsv[:, :, 2] = V if b < 1 else V.clip(None, 255)
            # Convert into a NEW array. Writing into `img` in place would also
            # modify the cached copy in self.imgs, so the colour jitter would
            # accumulate on the same image from epoch to epoch.
            img = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR)

        # Letterbox
        h, w, _ = img.shape
        if self.rect:
            shape = self.batch_shapes[self.batch[index]]
            img, ratiow, ratioh, padw, padh = letterbox(img, new_shape=shape, mode='rect')
        else:
            shape = self.img_size
            img, ratiow, ratioh, padw, padh = letterbox(img, new_shape=shape, mode='square')

        # Load labels
        labels = []
        if os.path.isfile(label_path):
            x = self.labels[index]
            if x is None:  # labels not preloaded
                with open(label_path, 'r') as f:
                    x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
                    self.labels[index] = x  # save for next time
            # The image was rescaled and padded, so the boxes must follow.
            if x.size > 0:
                # Normalized xywh to pixel xyxy format.
                # w and h are the pre-letterbox image size; ratiow / ratioh
                # and padw / padh are the scale and padding letterbox applied.
                labels = x.copy()
                labels[:, 1] = ratiow * w * (x[:, 1] - x[:, 3] / 2) + padw  # x1
                labels[:, 2] = ratioh * h * (x[:, 2] - x[:, 4] / 2) + padh  # y1
                labels[:, 3] = ratiow * w * (x[:, 1] + x[:, 3] / 2) + padw  # x2
                labels[:, 4] = ratioh * h * (x[:, 2] + x[:, 4] / 2) + padh  # y2

        # Augment image and labels
        if self.augment:
            img, labels = random_affine(img, labels,
                                        degrees=hyp['degrees'],
                                        translate=hyp['translate'],
                                        scale=hyp['scale'],
                                        shear=hyp['shear'])

        nL = len(labels)  # number of labels
        if nL:
            # convert xyxy to xywh
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])

            # Normalize coordinates 0 - 1 by the letterboxed image size
            labels[:, [2, 4]] /= img.shape[0]  # height
            labels[:, [1, 3]] /= img.shape[1]  # width

        if self.augment:
            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() > 0.5:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() > 0.5:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

        # Column 0 is reserved for the image index within the batch.
        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Normalize
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img, dtype=np.float32)  # uint8 to float32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0

        return torch.from_numpy(img), labels_out, img_path, (h, w)

    @staticmethod
    def collate_fn(batch):
        """Merge samples into a batch, tagging each label row with its image index.

        Returns:
            ``(images, labels, paths, hw)``: images stacked along a new first
            dimension, all label rows concatenated with column 0 set to the
            position of their image in the batch, and the tuples of paths and
            ``(h, w)`` pairs.
        """
        img, label, path, hw = list(zip(*batch))  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, hw

2. utils.py

2.1 build_targets函数

从这里开始,我参考的pytorch版本为:https://github.com/eriklindernoren/PyTorch-YOLOv3
因为之前的那个版本写的不够清晰,而这个版本的代码我也验证过,所以后续都是针对这个版本分析,其中模块构建部分大同小异。

# Excerpt from the YOLO layer's forward pass: build the training targets for
# this layer. The closing parenthesis of the call was missing in the excerpt.
iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf = build_targets(
                pred_boxes=pred_boxes,
                pred_cls=pred_cls,
                target=targets,
                anchors=self.scaled_anchors,
                ignore_thres=self.ignore_thres,
)
def build_targets(pred_boxes, pred_cls, target, anchors, ignore_thres):
    """Build the per-cell regression / classification targets for one YOLO layer.

    Args:
        pred_boxes: Predicted boxes, indexed as ``[batch, anchor, gj, gi]``
            with the box in the last dimension, e.g. ``(8, 3, 13, 13, 4)``.
        pred_cls: Predicted class scores, e.g. ``(8, 3, 13, 13, 80)``.
        target: Ground truth of shape ``(num_gt, 6)``. Columns 0-1 are read as
            (batch index, class label) and columns 2-5 as (x, y, w, h); the
            box is multiplied by the grid size, so it is presumably
            normalized to 0-1 -- confirm against the data loader.
        anchors: Anchor (w, h) pairs for this layer, e.g. shape ``(3, 2)``,
            in the same grid units as the scaled targets.
        ignore_thres: Anchors whose width/height IoU with a ground truth
            exceeds this value are removed from the no-object mask.

    Returns:
        ``(iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls,
        tconf)``. ``obj_mask`` / ``noobj_mask`` are uint8 tensors (dtype kept
        for existing callers); every other tensor is float32.
    """
    # Allocate on the device of the predictions. The legacy
    # torch.cuda.FloatTensor(...) constructors used before always allocate on
    # the *current* CUDA device, which may differ from pred_boxes.device.
    device = pred_boxes.device

    nB = pred_boxes.size(0)  # batch size
    nA = pred_boxes.size(1)  # anchors per cell
    nC = pred_cls.size(-1)   # number of classes
    nG = pred_boxes.size(2)  # grid size

    # Output tensors, e.g. nB=8, nA=3, nG=13, nC=80
    grid = (nB, nA, nG, nG)
    obj_mask = torch.zeros(grid, dtype=torch.uint8, device=device)
    noobj_mask = torch.ones(grid, dtype=torch.uint8, device=device)
    class_mask = torch.zeros(grid, dtype=torch.float32, device=device)
    iou_scores = torch.zeros(grid, dtype=torch.float32, device=device)
    tx = torch.zeros(grid, dtype=torch.float32, device=device)
    ty = torch.zeros(grid, dtype=torch.float32, device=device)
    tw = torch.zeros(grid, dtype=torch.float32, device=device)
    th = torch.zeros(grid, dtype=torch.float32, device=device)
    tcls = torch.zeros(grid + (nC,), dtype=torch.float32, device=device)

    # Convert the target boxes to grid units
    target_boxes = target[:, 2:6] * nG
    gxy = target_boxes[:, :2]
    gwh = target_boxes[:, 2:]

    # Get anchors with best iou: ious is (num_anchors, num_gt),
    # best_ious / best_n are (num_gt,)
    ious = torch.stack([bbox_wh_iou(anchor, gwh) for anchor in anchors])
    best_ious, best_n = ious.max(0)

    # Separate target values
    b, target_labels = target[:, :2].long().t()  # image index in batch, class label
    gx, gy = gxy.t()
    gw, gh = gwh.t()
    gi, gj = gxy.long().t()  # integer cell indices (column, row)

    # Set masks: the best anchor in the ground truth's cell is responsible for it
    obj_mask[b, best_n, gj, gi] = 1
    noobj_mask[b, best_n, gj, gi] = 0

    # Set noobj mask to zero where iou exceeds ignore threshold
    for i, anchor_ious in enumerate(ious.t()):
        noobj_mask[b[i], anchor_ious > ignore_thres, gj[i], gi[i]] = 0

    # Coordinates: offset of the box centre inside its cell
    tx[b, best_n, gj, gi] = gx - gx.floor()
    ty[b, best_n, gj, gi] = gy - gy.floor()
    # Width and height: log ratio to the responsible anchor (1e-16 avoids log(0))
    tw[b, best_n, gj, gi] = torch.log(gw / anchors[best_n][:, 0] + 1e-16)
    th[b, best_n, gj, gi] = torch.log(gh / anchors[best_n][:, 1] + 1e-16)
    # One-hot encoding of label
    tcls[b, best_n, gj, gi, target_labels] = 1
    # Compute label correctness and iou at best anchor
    class_mask[b, best_n, gj, gi] = (pred_cls[b, best_n, gj, gi].argmax(-1) == target_labels).float()
    iou_scores[b, best_n, gj, gi] = bbox_iou(pred_boxes[b, best_n, gj, gi], target_boxes, x1y1x2y2=False)

    tconf = obj_mask.float()
    return iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf

3. 训练

3.1 损失函数

损失函数由3部分构成:位置预测损失、置信度损失、分类损失;每一个部分都是针对bounding box而言的,其中只有当某个bounding box对一个真实的ground truth负责时,才会有位置预测损失和分类损失,而置信度损失每个bounding box都有,不同的是对真实的ground truth负责的bounding box和不负责的bounding box拥有不同的权重系数。

# Loss : Mask outputs to ignore non-existing objects (except with conf. loss)
            # NOTE(review): excerpt from a method body; self.mse_loss and
            # self.bce_loss are defined outside this excerpt (presumably
            # MSE and binary cross-entropy criteria -- confirm).
            # Box terms: evaluated only at the entries selected by obj_mask.
            loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])
            loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])
            loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
            loss_h = self.mse_loss(h[obj_mask], th[obj_mask])
            # Confidence terms: computed separately over the obj_mask and
            # noobj_mask entries, then combined with the two scale factors.
            loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])
            loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])
            loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj
            # Class term: again only at the obj_mask entries.
            loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask])
            # Unweighted sum of the box, confidence and class terms.
            total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

3.2 mAP计算

mAP本质是每个类别的 precision-recall 曲线与 x 轴围成的面积均值,那么每个类别的precision和recall是如何计算的呢?那么如何评价一个bounding box对真实ground truth预测正确与否呢?肯定是类别要相同、IOU要超过IOU阈值、置信度分数也要大于置信度阈值。

我们都知道,$precision = \frac{TP}{TP+FP}$,$recall = \frac{TP}{TP+FN}$。这里 TP 指的是预测为正样本且与真实 ground truth 匹配正确的数目;TP+FP 指的是所有被预测为正样本的数目(即全部检出框的数目);TP+FN 指的是真实 ground truth 的数目。

# Metrics
            # NOTE(review): excerpt from a method body; these are monitoring
            # values derived from the build_targets outputs.
            # Percentage of obj_mask entries whose class_mask value is set.
            cls_acc = 100 * class_mask[obj_mask].mean()
            # Mean predicted confidence over the obj_mask / noobj_mask entries.
            conf_obj = pred_conf[obj_mask].mean()
            conf_noobj = pred_conf[noobj_mask].mean()
            # 0/1 indicators: confidence above 0.5, IoU above 0.5 / 0.75.
            conf50 = (pred_conf > 0.5).float()
            iou50 = (iou_scores > 0.5).float()
            iou75 = (iou_scores > 0.75).float()
            # 1 only where confidence > 0.5, class_mask is set and tconf is 1.
            detected_mask = conf50 * class_mask * tconf
            # Precision divides by the number of confident predictions, recall
            # by the number of obj_mask entries; 1e-16 avoids division by zero.
            precision = torch.sum(iou50 * detected_mask) / (conf50.sum() + 1e-16)
            recall50 = torch.sum(iou50 * detected_mask) / (obj_mask.sum() + 1e-16)
            recall75 = torch.sum(iou75 * detected_mask) / (obj_mask.sum() + 1e-16)
def ap_per_class(tp, conf, pred_cls, target_cls):
    """Compute precision, recall, AP and F1 for every ground-truth class.

    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:    True-positive flag (0/1) of every prediction (array).
        conf:  Objectness value from 0-1 of every prediction (array).
        pred_cls: Predicted class of every prediction (array).
        target_cls: Class of every ground-truth object (array).
    # Returns
        (p, r, ap, f1, classes): per-class precision, recall, average
        precision and F1 arrays, plus the unique target classes as int32.
    """
    # Rank the predictions from most to least confident
    order = np.argsort(-conf)
    tp, conf, pred_cls = tp[order], conf[order], pred_cls[order]

    # Only classes that occur in the ground truth are evaluated
    unique_classes = np.unique(target_cls)

    ap, p, r = [], [], []
    for c in tqdm.tqdm(unique_classes, desc="Computing AP"):
        is_c = pred_cls == c
        n_gt = (target_cls == c).sum()  # ground-truth objects of this class
        n_p = is_c.sum()  # predictions of this class

        if n_p == 0 and n_gt == 0:
            continue
        if n_p == 0 or n_gt == 0:
            # Nothing predicted (or nothing to find): all metrics are zero
            ap.append(0)
            r.append(0)
            p.append(0)
            continue

        # Running true / false positive counts down the ranked list
        hits = tp[is_c]
        tpc = hits.cumsum()
        fpc = (1 - hits).cumsum()

        recall_curve = tpc / (n_gt + 1e-16)
        precision_curve = tpc / (tpc + fpc)

        # The last point of each curve is the value over all predictions
        r.append(recall_curve[-1])
        p.append(precision_curve[-1])
        ap.append(compute_ap(recall_curve, precision_curve))

    # F1 score: harmonic mean of precision and recall
    p, r, ap = np.array(p), np.array(r), np.array(ap)
    f1 = 2 * p * r / (p + r + 1e-16)

    return p, r, ap, f1, unique_classes.astype("int32")


def compute_ap(recall, precision):
    """Return the area under the precision envelope of a PR curve.

    Code originally from https://github.com/rbgirshick/py-faster-rcnn.

    # Arguments
        recall:    The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """
    # Pad both curves with sentinel points at recall 0 and recall 1
    padded_recall = np.concatenate(([0.0], recall, [1.0]))
    padded_precision = np.concatenate(([0.0], precision, [0.0]))

    # Precision envelope: each point becomes the maximum precision found at
    # that position or anywhere to its right (a running max taken backwards)
    envelope = np.maximum.accumulate(padded_precision[::-1])[::-1]

    # Only positions where recall actually changes contribute area
    steps = np.flatnonzero(padded_recall[1:] != padded_recall[:-1])

    # Sum of (delta recall) * envelope precision
    widths = padded_recall[steps + 1] - padded_recall[steps]
    return np.sum(widths * envelope[steps + 1])


def get_batch_statistics(outputs, targets, iou_threshold):
    """Compute true positives, predicted scores and predicted labels per sample.

    Args:
        outputs: One entry per image: ``None`` when there are no detections,
            otherwise a 2-D tensor whose columns 0-3 are the box, column 4
            the score and the last column the predicted label.
        targets: 2-D tensor whose column 0 is the image index, column 1 the
            label and the remaining columns the box (presumably in the same
            coordinate format as the predicted boxes -- confirm).
        iou_threshold: Minimum IoU for a prediction to match a target.

    Returns:
        A list with one ``[true_positives, pred_scores, pred_labels]`` entry
        for every image that has detections; ``true_positives`` is a NumPy
        0/1 array aligned with the predictions.
    """
    batch_metrics = []
    for sample_i, output in enumerate(outputs):
        if output is None:
            continue

        pred_boxes = output[:, :4]
        pred_scores = output[:, 4]
        pred_labels = output[:, -1]

        true_positives = np.zeros(pred_boxes.shape[0])

        annotations = targets[targets[:, 0] == sample_i][:, 1:]
        if len(annotations):
            target_labels = annotations[:, 0]
            target_boxes = annotations[:, 1:]
            detected_boxes = set()  # indices of targets that are already matched

            for pred_i, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):

                # Stop once every target has been matched
                if len(detected_boxes) == len(annotations):
                    break

                # Match only against targets of the predicted class. Taking
                # the best IoU over ALL targets would let a prediction count
                # as a true positive by overlapping a box of another class.
                candidates = (target_labels == pred_label).nonzero().flatten()
                if candidates.numel() == 0:
                    continue

                # assumes bbox_iou returns one IoU per target box (1-D) -- TODO confirm
                iou, best = bbox_iou(pred_box.unsqueeze(0), target_boxes[candidates]).max(0)
                box_index = int(candidates[best])  # back to an index into all targets
                if iou >= iou_threshold and box_index not in detected_boxes:
                    true_positives[pred_i] = 1
                    detected_boxes.add(box_index)
        batch_metrics.append([true_positives, pred_scores, pred_labels])
    return batch_metrics

你可能感兴趣的:(目标检测论文解读)