目标检测YOLOV3-训练集真实标注label预处理(三)

目录

1. 预处理标注文件

2. 制作数据生成器

3. 测试 


文件下载:https://download.csdn.net/download/qq_37116150/12289213

1. 预处理标注文件

首先将全局变量定义完成:

IMGSZ = 512 # 输入图片尺寸大小,必须是512x512
GRIDSZ = 16 # 网络最后输出的尺寸, 16x16
num_classes = 3 # 标签类别,共有3类,其中一类是背景,实际就两类,可根据需要修改
# YOLO的anchors的5个尺寸
ANCHORS = [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828]

为了损失函数,需要将标注信息的格式修改,应符合网络输出的格式: [b, 16, 16, 5, (4+1+num_classes)],具体代码如下:

def process_true_boxes(gt_boxes, anchors):
    """
    计算一张图片的真实标签信息
    :param gt_boxes:
    :param anchors:
    :return:
    """
    # gt_boxes: [40,5] 一张真实标签的位置坐标信息,40是虚数,根据实际情况来定
    # 512//16=32
    # 计算网络模型从输入到输出的缩小比例
    scale = IMGSZ // GRIDSZ
    # [5,2] 将anchors转化为矩阵形式,一行代表一个anchors
    anchors = np.array(anchors).reshape((5, 2))

    # mask for object
    # 用来判断该方格位置的anchors有没有目标,每个方格有5个anchors
    detector_mask = np.zeros([GRIDSZ, GRIDSZ, 5, 1])
    # x-y-w-h-l
    # 在输出方格的尺寸上[16, 16, 5]制作真实标签, 用于和预测输出值做比较,计算损失值
    matching_gt_box = np.zeros([GRIDSZ, GRIDSZ, 5, 5])
    # [40,5] x1-y1-x2-y2-l => x-y-w-h-l
    # 制作一个numpy变量,用于存储一张图片真实标签转换格式后的数据
    # 将左上角与右下角坐标转化为中心坐标与宽高的形式
    # [x_min, y_min, x_max, y_max] => [x_center, y_center, w, h]
    gt_boxes_grid = np.zeros(gt_boxes.shape)
    # DB: tensor => numpy 方便计算
    gt_boxes = gt_boxes.numpy()

    for i,box in enumerate(gt_boxes): # [40,5]
        # box: [5], x1-y1-x2-y2-l,逐行读取
        # 512 => 16
        # 将左上角与右下角坐标转化为中心坐标与宽高的形式
        # [x_min, y_min, x_max, y_max] => [x_center, y_center, w, h]
        x = ((box[0]+box[2])/2)/scale
        y = ((box[1]+box[3])/2)/scale
        w = (box[2] - box[0]) / scale
        h = (box[3] - box[1]) / scale
        # [40,5] x_center-y_center-w-h-l
        # 将第 i 行的数据赋予计算得到的新数据
        gt_boxes_grid[i] = np.array([x,y,w,h,box[4]])

        if w*h > 0: # valid box
            # 用于筛选有效数据,当w, h为0时,表明该行没有目标,为无效的填充数据0
            # x,y: 7.3, 6.8 都是缩放后的中心坐标
            best_anchor = 0
            best_iou = 0
            for j in range(5):
                # 计算真实目标框有5个anchros的交并比,选出做好的一个anchors
                interct = np.minimum(w, anchors[j,0]) * np.minimum(h, anchors[j,1])
                union = w*h + (anchors[j,0]*anchors[j,1]) - interct
                iou = interct / union

                if iou > best_iou: # best iou 筛选最大的iou,即最好的anchors
                    best_anchor = j # 将更加优秀的anchors的索引赋值与之前定义好的变量
                    best_iou = iou # 记录最好的iou
            # found the best anchors
            if best_iou>0: #用于判断是否有anchors与真实目标产生交并
               # 向下取整,即是将中心点坐标转化为左上角坐标, 用于后续计算赋值
               x_coord = np.floor(x).astype(np.int32)
               y_coord = np.floor(y).astype(np.int32)
               # [b,h,w,5,1]
               # 将最好的一个anchors赋值1,别的anchors默认为0
               # 图像坐标系的坐标与数组的坐标互为转置:[x,y] => [y, x]
               detector_mask[y_coord, x_coord, best_anchor] = 1
               # [b,h,w,5,x-y-w-h-l]
               # 将最好的一个anchors赋值真实标签的信息[x_center, y_center, w, h, label],别的anchors默认为0
               matching_gt_box[y_coord, x_coord, best_anchor] = \
                   np.array([x,y,w,h,box[4]])

    # [40,5] => [16,16,5,5]
    # matching_gt_box:[16,16,5,5],用于计算损失值
    # detector_mask:[16,16,5,1],掩码,判断哪个anchors有目标
    # gt_boxes_grid:[40,5],一张图片中目标的位置信息,转化后的格式
    return matching_gt_box, detector_mask, gt_boxes_grid

2. 制作数据生成器

为了后续训练,方便数据使用,制作一个数据生成器,代码如下:

def ground_truth_generator(db):
    """
    构建一个训练数据集迭代器,每次迭代的数量由batch决定
    :param db:训练集队列,包含训练集原图片数据信息,标签位置[x_min, y_min, x_max, y_max, label]信息
    :return:
    """
    for imgs, imgs_boxes in db:
        # imgs: [b,512,512,3] b的值由之前定义的batch_size来决定
        # imgs_boxes: [b,40,5],不一定是40,要根据实际情况来判断

        # 创建三个批量数据列表
        # 对应上面函数的单个图片数据变量
        batch_matching_gt_box = []
        batch_detector_mask = []
        batch_gt_boxes_grid = []

        # print(imgs_boxes[0,:5])

        b = imgs.shape[0] # 计算一个batch有多少张图片
        for i in range(b): # for each image
            matching_gt_box, detector_mask, gt_boxes_grid = \
                process_true_boxes(gt_boxes=imgs_boxes[i], anchors=ANCHORS)
            batch_matching_gt_box.append(matching_gt_box)
            batch_detector_mask.append(detector_mask)
            batch_gt_boxes_grid.append(gt_boxes_grid)
        # 将其转化为矩阵形式并转化为tensor,[b, 16,16,5,1]
        detector_mask = tf.cast(np.array(batch_detector_mask), dtype=tf.float32)
        # 将其转化为矩阵形式并转化为tensor,[b,16,16,5,5] x_center-y_center-w-h-l
        matching_gt_box = tf.cast(np.array(batch_matching_gt_box), dtype=tf.float32)
        # 将其转化为矩阵形式并转化为tensor,[b,40,5] x_center-y_center-w-h-l
        gt_boxes_grid = tf.cast(np.array(batch_gt_boxes_grid), dtype=tf.float32)

        # [b,16,16,5]
        # 将所有的label信息单独分出来,用于后续计算分类损失值
        matching_classes = tf.cast(matching_gt_box[...,4], dtype=tf.int32)
        # 将标签进行独热码编码 [b,16,16,5,num_classes:3],
        matching_classes_oh = tf.one_hot(matching_classes, depth=num_classes)
        # 将背景标签去除,背景为0
        # x_center-y_center-w-h-conf-l0-l1-l2 => x_center-y_center-w-h-conf-l1-l2
        # [b,16,16,5,2]
        matching_classes_oh = tf.cast(matching_classes_oh[...,1:], dtype=tf.float32)


        # [b,512,512,3]
        # [b,16,16,5,1]
        # [b,16,16,5,5]
        # [b,16,16,5,2]
        # [b,40,5]
        yield imgs, detector_mask, matching_gt_box, matching_classes_oh,gt_boxes_grid

3. 测试

测试代码如下:

# if __name__ == "__main__":
#     ## 训练集路径
#     img_path = "data\\val\\image"
#     annotation_path = "data\\val\\annotation"
#     label = ("sugarbeet", "weed")
#     train_db = get_datasets(img_path, annotation_path, label)
#
#     train_gen = ground_truth_generator(train_db)
#
#     img, detector_mask, matching_gt_box, matching_classes_oh, gt_boxes_grid = \
#         next(train_gen)
#     img, detector_mask, matching_gt_box, matching_classes_oh, gt_boxes_grid = \
#         img[0], detector_mask[0], matching_gt_box[0], matching_classes_oh[0], gt_boxes_grid[0]
#
#     fig, (ax1, ax2) = plt.subplots(2, figsize=(5, 10))
#     ax1.imshow(img)
#     # [16,16,5,1] => [16,16,1]
#     mask = tf.reduce_sum(detector_mask, axis=2)
#     ax2.matshow(mask[..., 0])  # [16,16]
#     plt.show()

未完待续。。。。。

你可能感兴趣的:(深度学习)