Modifying the Box-Drawing Code in Facebook's SlowFast Demo

An analysis of the relevant box-drawing source code can be found here.

I. Modification Plan

1. Modify the TaskInfo class in utils.py (around line 342)

class TaskInfo:
    def __init__(self):
        self.frames = None
        self.id = -1
        self.bboxes = None
        self.action_preds = None
        self.detect_preds = None
        self.num_buffer_frames = 0
        self.img_height = -1
        self.img_width = -1
        self.crop_size = -1
        self.clip_vis_size = -1

    def add_frames(self, idx, frames):
        """
        Add the clip and corresponding id.
        Args:
            idx (int): the current index of the clip.
            frames (list[ndarray]): list of images in "BGR" format.
        """
        self.frames = frames
        self.id = idx

    def add_bboxes(self, bboxes):
        """
Add the corresponding bounding boxes.
        """
        self.bboxes = bboxes

    def add_action_preds(self, preds):
        """
        Add the corresponding action predictions.
        """
        self.action_preds = preds

    def add_detect_preds(self, preds):
        """
        Add the corresponding detection predictions.
        """
        self.detect_preds = preds
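
With the setter fixed, detection and action predictions live in separate fields. A minimal sanity check (the label string and tensor values are made up for illustration):

import torch

task = TaskInfo()
task.add_detect_preds(["[0.9876] person"])   # detection label from Detectron2
task.add_action_preds(torch.rand(1, 80))     # action scores, shape (num_boxes, num_classes)
assert task.detect_preds == ["[0.9876] person"]
assert task.action_preds is not None         # no longer clobbered by add_detect_preds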

2. Modify the Detectron2 call site to add the detection-box code

predictor.py (around line 157)

from slowfast.visualization.video_visualizer import get_class_names

class Detectron2Predictor:
    """
    Wrapper around Detectron2 to return the required predicted bounding boxes
    as a ndarray.
    """

    def __init__(self, cfg, gpu_id=None):
        """
        Args:
            cfg (CfgNode): configs. Details can be found in
                slowfast/config/defaults.py
            gpu_id (Optional[int]): GPU id.
        """

        self.cfg = get_cfg()
        self.cfg.merge_from_file(
            model_zoo.get_config_file(cfg.DEMO.DETECTRON2_CFG)
        )
        self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = cfg.DEMO.DETECTRON2_THRESH
        self.cfg.MODEL.WEIGHTS = cfg.DEMO.DETECTRON2_WEIGHTS
        self.cfg.INPUT.FORMAT = cfg.DEMO.INPUT_FORMAT
        """ 从json文件中获取类别及编号 """
        self.detect_names,_,_ =get_class_names(cfg.DEMO.Detect_File_Path,None,None)
        if cfg.NUM_GPUS and gpu_id is None:
            gpu_id = torch.cuda.current_device()
        self.cfg.MODEL.DEVICE = (
            "cuda:{}".format(gpu_id) if cfg.NUM_GPUS > 0 else "cpu"
        )

        logger.info("Initialized Detectron2 Object Detection Model.")

        self.predictor = DefaultPredictor(self.cfg)

    def __call__(self, task):
        """
        Return bounding boxes predictions as a tensor.
        Args:
            task (TaskInfo object): task object that contains
                the necessary information for action prediction. (e.g. frames)
        Returns:
            task (TaskInfo object): the same task info object but filled with
                prediction values (a tensor) and the corresponding boxes for
                action detection task.
        """
        # middle_frame = task.frames[len(task.frames) // 2]
        # outputs = self.predictor(middle_frame)
        # # Get only human instances
        # mask = outputs["instances"].pred_classes == 0
        # pred_boxes = outputs["instances"].pred_boxes.tensor[mask]
        # task.add_bboxes(pred_boxes)

        middle_frame = task.frames[len(task.frames) // 2]
        outputs = self.predictor(middle_frame)
        # Keep only "person" instances above a confidence threshold.
        mask = (outputs["instances"].scores >= 0.7) & (
            outputs["instances"].pred_classes == 0
        )
        # Predicted boxes.
        pred_boxes = outputs["instances"].pred_boxes.tensor[mask]
        # Prediction confidence scores.
        scores = outputs["instances"].scores[mask].tolist()
        # Predicted class ids.
        pred_labels = outputs["instances"].pred_classes[mask].tolist()
        # Map class ids to class names.
        for i in range(len(pred_labels)):
            pred_labels[i] = self.detect_names[pred_labels[i]]
        preds = [
            "[{:.4f}] {}".format(s, label) for s, label in zip(scores, pred_labels)
        ]
        # Attach detection labels and boxes to the task.
        task.add_detect_preds(preds)
        task.add_bboxes(pred_boxes)

        return task
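
Note that cfg.DEMO.Detect_File_Path is a config entry added for this modification (it presumably also needs to be registered in slowfast/config/defaults.py), and get_class_names is the same helper the demo uses for its action label map, so the file it points to should be a JSON dict mapping class names to ids. A minimal sketch of such a file for COCO classes (the file name and the three-class subset are placeholders):

import json

# Ids must match Detectron2's 0-based COCO class indices ("person" is 0).
coco_subset = {"person": 0, "bicycle": 1, "car": 2}  # extend to all 80 classes
with open("coco_classes.json", "w") as f:
    json.dump(coco_subset, f)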

3. Modify the corresponding drawing functions

(1) async_predictor.py (around line 276)

def draw_predictions(task, video_vis):
    """
    Draw prediction for the given task.
    Args:
        task (TaskInfo object): task object that contains
            the necessary information for visualization. (e.g. frames, preds)
            All attributes must lie on CPU devices.
        video_vis (VideoVisualizer object): the video visualizer object.
    """
    boxes = task.bboxes
    frames = task.frames
    preds = task.action_preds
    """ 提出检测类别"""
    detect_pred=task.detect_preds
    if boxes is not None:
        img_width = task.img_width
        img_height = task.img_height
        if boxes.device != torch.device("cpu"):
            boxes = boxes.cpu()
        boxes = cv2_transform.revert_scaled_boxes(
            task.crop_size, boxes, img_height, img_width
        )

    keyframe_idx = len(frames) // 2 - task.num_buffer_frames
    draw_range = [
        keyframe_idx - task.clip_vis_size,
        keyframe_idx + task.clip_vis_size,
    ]
    buffer = frames[: task.num_buffer_frames]
    frames = frames[task.num_buffer_frames :]
    if boxes is not None:
        if len(boxes) != 0:
            # Pass the extra detect_pred argument into the modified
            # draw_clip_range (see section 3(2) below).
            frames = video_vis.draw_clip_range(
                frames,
                preds,
                detect_pred,
                boxes,
                keyframe_idx=keyframe_idx,
                draw_range=draw_range,
            )
    else:
        # detect_pred is now a required positional parameter, so it must be
        # passed here as well (it is unused when there are no boxes).
        frames = video_vis.draw_clip_range(
            frames,
            preds,
            detect_pred,
            keyframe_idx=keyframe_idx,
            draw_range=draw_range,
        )
    del task

    return buffer + frames
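
For context, a rough sketch of how these pieces are wired together at runtime. The instance names object_detector, action_predictor, and video_vis are hypothetical, and the action-prediction step belongs to the demo pipeline and is not covered in this post:

import numpy as np

frames = [np.zeros((360, 640, 3), dtype=np.uint8) for _ in range(64)]  # dummy BGR frames
task = TaskInfo()
task.add_frames(0, frames)
task = object_detector(task)      # Detectron2Predictor fills bboxes and detect_preds
task = action_predictor(task)     # fills task.action_preds (part of the demo pipeline)
frames_out = draw_predictions(task, video_vis)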

(2) video_visualizer.py (around line 514)

"""加入detect_pred属性"""
    def draw_clip_range(
        self,
        frames,
        preds,
        detect_pred,
        bboxes=None,
        text_alpha=0.5,
        ground_truth=False,
        keyframe_idx=None,
        draw_range=None,
        repeat_frame=1,
    ):
        """
        Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
        if bboxes is provided. Boxes will gradually fade in and out the clip, centered around
        the clip's central frame, within the provided `draw_range`.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains all of the confidence scores
                of the model. For recognition task or for ground_truth labels, input shape can be (num_classes,).
            detect_pred (list[str]): per-box detection labels ("[score] class_name")
                produced by Detectron2Predictor.
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip.
            draw_range (Optional[list[int]]): only draw frames in range [start_idx, end_idx] inclusively in the clip.
                If None, draw on the entire clip.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame` time for slow-motion effect.
        """
        if draw_range is None:
            draw_range = [0, len(frames) - 1]
        if draw_range is not None:
            draw_range[0] = max(0, draw_range[0])
            left_frames = frames[: draw_range[0]]
            right_frames = frames[draw_range[1] + 1 :]

        draw_frames = frames[draw_range[0] : draw_range[1] + 1]
        if keyframe_idx is None:
            keyframe_idx = len(frames) // 2

        img_ls = (
            list(left_frames)
            """修改draw_clip函数"""
            + self.draw_clip(
                draw_frames,
                preds,
                detect_pred,
                bboxes=bboxes,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
                keyframe_idx=keyframe_idx - draw_range[0],
                repeat_frame=repeat_frame,
            )
            + list(right_frames)
        )

        return img_ls

(3) video_visualizer.py (around line 568)

"""加入detect_pred参数"""
    def draw_clip(
        self,
        frames,
        preds,
        detect_pred,
        bboxes=None,
        text_alpha=0.5,
        ground_truth=False,
        keyframe_idx=None,
        repeat_frame=1,
    ):
        """
        Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
        if bboxes is provided. Boxes will gradually fade in and out the clip, centered around
        the clip's central frame.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains all of the confidence scores
                of the model. For recognition task or for ground_truth labels, input shape can be (num_classes,).
            detect_pred (list[str]): per-box detection labels ("[score] class_name")
                produced by Detectron2Predictor.
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame` time for slow-motion effect.
        """
        assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."

        repeated_seq = range(0, len(frames))
        repeated_seq = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, repeat_frame) for x in repeated_seq
            )
        )

        frames, adjusted = self._adjust_frames_type(frames)
        if keyframe_idx is None:
            half_left = len(repeated_seq) // 2
            half_right = (len(repeated_seq) + 1) // 2
        else:
            mid = int((keyframe_idx / len(frames)) * len(repeated_seq))
            half_left = mid
            half_right = len(repeated_seq) - mid

        alpha_ls = np.concatenate(
            [
                np.linspace(0, 1, num=half_left),
                np.linspace(1, 0, num=half_right),
            ]
        )
        frames = frames[repeated_seq]
        img_ls = []
        for alpha, frame in zip(alpha_ls, frames):
        """修改draw_one_frame函数"""
            draw_img = self.draw_one_frame(
                frame,
                preds,
                detect_pred,
                bboxes,
                alpha=alpha,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
            )
            if adjusted:
                draw_img = draw_img.astype("float32") / 255

            img_ls.append(draw_img)

        return img_ls
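
To make the fade schedule above concrete, here is a small worked example: for an 8-frame clip with the keyframe at index 4 and repeat_frame=1, the box alpha ramps from 0 up to 1 at the keyframe and back down to 0:

import numpy as np

n, keyframe_idx = 8, 4
half_left = keyframe_idx          # mid = int((4 / 8) * 8) = 4
half_right = n - keyframe_idx
alpha_ls = np.concatenate(
    [np.linspace(0, 1, num=half_left), np.linspace(1, 0, num=half_right)]
)
print(np.round(alpha_ls, 2))      # [0.   0.33 0.67 1.   1.   0.67 0.33 0.  ]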

(4) video_visualizer.py (around line 404)

"""加入detect_pred属性"""
    def draw_one_frame(
        self,
        frame,
        preds,
        detect_pred,
        bboxes=None,
        alpha=0.5,
        text_alpha=0.7,
        ground_truth=False,
    ):
        """
        Draw labels and bounding boxes for one image. By default, predicted labels are drawn in
        the top left corner of the image or corresponding bounding boxes. For ground truth labels
        (setting True for ground_truth flag), labels will be drawn in the bottom left corner.
        Args:
            frame (array-like): a tensor or numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            preds (tensor or list): If ground_truth is False, provide a float tensor of shape (num_boxes, num_classes)
                that contains all of the confidence scores of the model.
                For recognition task, input shape can be (num_classes,). To plot true labels (ground_truth is True),
                preds is a list of int32 class ids of the shape (num_boxes, true_class_ids) or (true_class_ids,).
            detect_pred (list[str]): per-box detection labels ("[score] class_name")
                produced by Detectron2Predictor.
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes.
            alpha (Optional[float]): transparency level of the bounding boxes.
            text_alpha (Optional[float]): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
        """
        if isinstance(preds, torch.Tensor):
            if preds.ndim == 1:
                preds = preds.unsqueeze(0)
            n_instances = preds.shape[0]
        elif isinstance(preds, list):
            n_instances = len(preds)
        else:
            logger.error("Unsupported type of prediction input.")
            return

        if ground_truth:
            top_scores, top_classes = [None] * n_instances, preds

        elif self.mode == "top-k":
            top_scores, top_classes = torch.topk(preds, k=self.top_k)
            top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
        elif self.mode == "thres":
            top_scores, top_classes = [], []
            for pred in preds:
                mask = pred >= self.thres
                top_scores.append(pred[mask].tolist())
                top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
                top_classes.append(top_class)

        # Create labels for the top-k predicted classes with their scores.
        text_labels = []
        for i in range(n_instances):
            text_labels.append(
                _create_text_labels(
                    top_classes[i],
                    top_scores[i],
                    self.class_names,
                    ground_truth=ground_truth,
                )
            )
        frame_visualizer = ImgVisualizer(frame, meta=None)
        font_size = min(
            max(np.sqrt(frame.shape[0] * frame.shape[1]) // 35, 5), 9
        )
        top_corner = not ground_truth
        if bboxes is not None:
            assert len(preds) == len(
                bboxes
            ), "Encounter {} predictions and {} bounding boxes".format(
                len(preds), len(bboxes)
            )
            for i, box in enumerate(bboxes):
                text = text_labels[i]
                """加入检测标签"""
                text.append(detect_pred[i])
                pred_class = top_classes[i]
                """默认目标检测标签颜色"""
                colors1=[(0.9019607843137225,0.9607843137254902,0.788235294117667)]
                """为行为检测标签颜色"""
                colors2 = [self._get_color(pred) for pred in pred_class]
                colors=colors1+colors2

                box_color = "r" if ground_truth else "g"
                line_style = "--" if ground_truth else "-."
                frame_visualizer.draw_box(
                    box,
                    alpha=alpha,
                    edge_color=box_color,
                    line_style=line_style,
                )
                frame_visualizer.draw_multiple_text(
                    text,
                    box,
                    top_corner=top_corner,
                    font_size=font_size,
                    box_facecolors=colors,
                    alpha=text_alpha,
                )
        else:
            text = text_labels[0]
            pred_class = top_classes[0]
            colors = [self._get_color(pred) for pred in pred_class]
            frame_visualizer.draw_multiple_text(
                text,
                torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
                top_corner=top_corner,
                font_size=font_size,
                box_facecolors=colors,
                alpha=text_alpha,
            )

        return frame_visualizer.output.get_image()
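
For a single box, the label list and its colors now line up as sketched below (all values are made up). Note that the detection label is appended to the end of text while its fixed color is prepended to colors; if draw_multiple_text pairs box_facecolors with labels in order, using text.insert(0, detect_pred[i]) instead would keep the two lists aligned:

text = ["walk: 0.93", "talk to: 0.71"]         # from _create_text_labels
text.append("[0.9876] person")                  # detect_pred[i]
colors1 = [(230 / 255, 245 / 255, 201 / 255)]   # fixed detection-label color
colors2 = [(0.2, 0.6, 0.9), (0.9, 0.4, 0.2)]    # per-action colors (stand-ins)
colors = colors1 + colors2                      # detection color ends up first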

4. Before-and-After Comparison

Before the modification:
[Screenshot 1: demo output before the change]
After the modification:
[Screenshot 2: demo output after the change]
