The concrete implementation of the datapipeline in SiamFC++

The datapipeline first shows up in videoanalyst/data/adaptor_dataset.py, which defines the dataset class (AdaptorDataset, inheriting from torch.utils.data.Dataset) and creates the datapipeline inside it:

self.datapipeline = datapipeline_builder.build(self.task,
                                               self.cfg,
                                               seed=seed)
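
For orientation, here is a condensed sketch of the surrounding class; the attribute names follow the snippets quoted in this post, while the constructor signature is abridged and partly assumed:

    from torch.utils.data import Dataset

    class AdaptorDataset(Dataset):
        def __init__(self, task, cfg, seed=0):  # signature abridged/assumed
            self.task = task
            self.cfg = cfg
            self.ext_seed = seed
            # built lazily on the first __getitem__ call,
            # see the Dataloader section at the end
            self.datapipeline = None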

videoanalyst/data/datapipeline/builder.py defines how the whole datapipeline is assembled. It builds three kinds of components:

  1. sampler
  2. transformers
  3. target

sampler = build_sampler(task, cfg.sampler, seed=seed)
transformers = build_transformer(task, cfg.transformer, seed=seed)
target = build_target(task, cfg.target)

# the pipeline applies the transformers first and the target last
pipeline = []
pipeline.extend(transformers)
pipeline.append(target)

cfg = cfg.datapipeline
name = cfg.name
module = MODULES[name](sampler, pipeline)
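
Reading cfg.datapipeline.name and instantiating MODULES[name] yields an object whose __getitem__ feeds the sampler output through every pipeline stage in order. A condensed sketch of that behavior (the class name below is illustrative; the real implementation lives under videoanalyst/data/datapipeline/):

    class DatapipelineSketch:
        def __init__(self, sampler, pipeline):
            self.sampler = sampler    # produces raw sample pairs
            self.pipeline = pipeline  # transformers first, target last

        def __getitem__(self, item):
            data = self.sampler[item]   # dict(data1=..., data2=..., is_negative_pair=...)
            for proc in self.pipeline:  # random crop, then label generation
                data = proc(data)
            return data                 # the training_data dict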

sampler

videoanalyst/data/sampler/builder.py defines the overall framework of the sampler, which is itself assembled from submodules. The submodules fall into two groups: datasets and a data filter.

    submodules_cfg = cfg.submodules

    dataset_cfg = submodules_cfg.dataset
    datasets = dataset_builder.build(task, dataset_cfg)

    if submodules_cfg.filter.name != "":
        filter_cfg = submodules_cfg.filter
        data_filter = filter_builder.build(task, filter_cfg)
    else:
        data_filter = None

    name = cfg.name
    module = MODULES[name](datasets, seed=seed, data_filter=data_filter)

videoanalyst/data/sampler/sampler_impl/track_pair_sampler.py defines the concrete sampler, which returns a pair of samples:

sampled_data = dict(
    data1=data1,
    data2=data2,
    is_negative_pair=is_negative_pair,
)

datasets

For example, videoanalyst/data/dataset/dataset_impl/lasot.py defines what the dataset returns:

from videoanalyst.evaluation.got_benchmark.datasets import LaSOT

    def update_params(self):
        r"""
        an interface for update params
        """
        dataset_root = osp.realpath(self._hyper_params["dataset_root"])
        subset = self._hyper_params["subset"]
        check_integrity = self._hyper_params["check_integrity"]
        self._state["dataset"] = LaSOT(dataset_root,
                                       subset=subset,
                                       check_integrity=check_integrity)

    def __getitem__(self, item: int) -> Dict:
        img_files, anno = self._state["dataset"][item]

        anno = xywh2xyxy(anno)
        sequence_data = dict(image=img_files, anno=anno)

        return sequence_data

    def __len__(self):
        return len(self._state["dataset"])

videoanalyst/evaluation/got_benchmark/datasets/lasot.py defines the return value of its __getitem__:

return img_files, anno

which is exactly what gets wrapped into sequence_data above.
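
Concretely, img_files is the list of per-frame image paths and anno holds one box per frame, so sequence_data ends up shaped roughly like this (paths and numbers are made up for illustration):

    sequence_data = dict(
        image=[
            "LaSOT/airplane/airplane-1/img/00000001.jpg",
            "LaSOT/airplane/airplane-1/img/00000002.jpg",
        ],
        anno=[
            [367.0, 101.0, 529.0, 304.0],  # [x1, y1, x2, y2] after xywh2xyxy
            [369.0, 103.0, 531.0, 306.0],
        ],
    )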

filter

videoanalyst/data/filter/filter_impl/track_pair_filter.py decides whether a sampled item should be rejected: it returns True (reject, so the sampler tries again) when there is no data, or when the target is too small, too large, or too elongated; otherwise it returns False (accept).

    def __call__(self, data: Dict) -> bool:
        if data is None:
            return True  # nothing was sampled: reject so the sampler retries
        im, anno = data["image"], data["anno"]
        if self._hyper_params["target_type"] == "bbox":
            bbox = xyxy2xywh(anno)
        elif self._hyper_params["target_type"] == "mask":
            bbox = cv2.boundingRect(anno)
        else:
            logger.error("unsupported target type {} in filter".format(
                self._hyper_params["target_type"]))
            exit()
        # True if the box is too small / too large / too elongated
        filter_flag = filter_unreasonable_training_boxes(
            im, bbox, self._hyper_params)
        return filter_flag

With that in place, let's step back and look at what the sampler actually does.

    Sample procedure:
    __getitem__
    │
    ├── _sample_track_pair                                # returns a pair of frames with annotations, wrapped in dicts
    │   ├── _sample_dataset                               # randomly pick a dataset
    │   ├── _sample_sequence_from_dataset                 # randomly pick a sequence from it
    │   ├── _sample_track_frame_from_static_image         # for static-image datasets, e.g. COCO
    │   └── _sample_track_pair_from_sequence              # for video datasets, e.g. LaSOT
    │           └── _sample_pair_idx_pair_within_max_diff # pick a frame pair within the maximum frame gap
    │
    └── _sample_track_frame
        ├── _sample_dataset
        ├── _sample_sequence_from_dataset
        ├── _sample_track_frame_from_static_image (x2)
        └── _sample_track_frame_from_sequence

The implementation is as follows:

    def __getitem__(self, item) -> dict:
        is_negative_pair = (self._state["rng"].rand() <
                            self._hyper_params["negative_pair_ratio"])
        data1 = data2 = None
        sample_try_num = 0
        # data1/data2 start as None, so the filter rejects them and the loop
        # runs at least once; keep re-sampling until both samples pass
        while self.data_filter(data1) or self.data_filter(data2):
            if is_negative_pair:
                # negative pair: two frames drawn independently
                data1 = self._sample_track_frame()
                data2 = self._sample_track_frame()
            else:
                data1, data2 = self._sample_track_pair()
            data1["image"] = load_image(data1["image"])
            data2["image"] = load_image(data2["image"])
            sample_try_num += 1
        sampled_data = dict(
            data1=data1,
            data2=data2,
            is_negative_pair=is_negative_pair,
        )

        return sampled_data
        
    def _sample_track_pair(self) -> Tuple[Dict, Dict]:
        dataset_idx, dataset = self._sample_dataset()
        sequence_data = self._sample_sequence_from_dataset(dataset)
        len_seq = self._get_len_seq(sequence_data)
        if len_seq == 1 and not isinstance(sequence_data["anno"][0], list):
            # static image dataset
            data1 = self._sample_track_frame_from_static_image(sequence_data)
            data2 = deepcopy(data1)
        else:
            # video dataset
            data1, data2 = self._sample_track_pair_from_sequence(
                sequence_data, self._state["max_diffs"][dataset_idx])

        return data1, data2

    def _sample_track_frame(self) -> Dict:
        _, dataset = self._sample_dataset()
        sequence_data = self._sample_sequence_from_dataset(dataset)
        len_seq = self._get_len_seq(sequence_data)
        if len_seq == 1:
            # static image dataset
            data_frame = self._sample_track_frame_from_static_image(
                sequence_data)
        else:
            # video dataset
            data_frame = self._sample_track_frame_from_sequence(sequence_data)

        return data_frame

Either way, what gets returned is sampled_data = dict(data1=data1, data2=data2, is_negative_pair=is_negative_pair).

transformer

videoanalyst/data/transformer/transformer_impl/random_crop_transformer.py performs the random crop of the raw images; the result is written back into sampled_data.

    def __call__(self, sampled_data: Dict) -> Dict:
        r"""
        sampled_data: Dict()
            input data
            Dict(data1=Dict(image, anno), data2=Dict(image, anno))
        """
        data1 = sampled_data["data1"]
        data2 = sampled_data["data2"]
        im_temp, bbox_temp = data1["image"], data1["anno"]
        im_curr, bbox_curr = data2["image"], data2["anno"]
        im_z, bbox_z, im_x, bbox_x, _, _ = crop_track_pair(
            im_temp,
            bbox_temp,
            im_curr,
            bbox_curr,
            config=self._hyper_params,
            rng=self._state["rng"])

        sampled_data["data1"] = dict(image=im_z, anno=bbox_z)
        sampled_data["data2"] = dict(image=im_x, anno=bbox_x)

        return sampled_data
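
Note the naming shift here: data1 becomes the template patch (im_z with bbox_z) and data2 the search patch (im_x with bbox_x), which is exactly the z/x convention the target stage below relies on.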

target

videoanalyst/data/target/target_impl/densebox_target.py builds on the transformer output and generates three kinds of labels: classification (cls), centerness (ctr), and box regression (box).

    def __call__(self, sampled_data: Dict) -> Dict:
        data_z = sampled_data["data1"]
        im_z, bbox_z = data_z["image"], data_z["anno"]

        data_x = sampled_data["data2"]
        im_x, bbox_x = data_x["image"], data_x["anno"]

        is_negative_pair = sampled_data["is_negative_pair"]

        # input tensor: HWC -> CHW
        im_z = im_z.transpose(2, 0, 1)
        im_x = im_x.transpose(2, 0, 1)

        # training target
        cls_label, ctr_label, box_label = make_densebox_target(
            bbox_x.reshape(1, 4), self._hyper_params)
        if is_negative_pair:
            # a negative pair has no true foreground:
            # background (0) -> ignore (-1), foreground (1) -> background (0)
            cls_label[cls_label == 0] = -1
            cls_label[cls_label == 1] = 0

        training_data = dict(
            im_z=im_z,
            im_x=im_x,
            bbox_z=bbox_z,
            bbox_x=bbox_x,
            cls_gt=cls_label,
            ctr_gt=ctr_label,
            box_gt=box_label,
            is_negative_pair=int(is_negative_pair),
        )
        #training_data = super().__call__(training_data)

        return training_data
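
The three labels essentially follow the dense per-location scheme of FCOS-style detectors: cls_label marks which locations count as foreground, box_label stores each location's distances to the four box edges, and ctr_label down-weights off-center locations. As a minimal illustration of the centerness idea (the FCOS formula; the numbers are made up):

    import numpy as np

    # distances from one feature-map location to the four gt box edges
    l, t, r, b = 40.0, 25.0, 10.0, 30.0
    centerness = np.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))
    print(round(float(centerness), 2))  # 0.46 -- off-center locations score lower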

Looking back at the datapipeline as a whole: the sampler picks an image pair from the datasets, the transformer crops it according to the configured z/x crop sizes, and the target stage generates the labels.

Dataloader

Finally, back to where everything starts: videoanalyst/data/builder.py sets up its own dataset class, videoanalyst/data/adaptor_dataset.py implements it, and that is where the datapipeline comes in. You can see the return value:

    def __getitem__(self, item):
        if self.datapipeline is None:
            # build datapipeline with random seed the first time when __getitem__ is called
            # usually, dataset is already spawned (into subprocess) at this point.
            seed = (torch.initial_seed() + item * self._SEED_STEP +
                    self.ext_seed * self._EXT_SEED_STEP) % self._SEED_DIVIDER
            self.datapipeline = datapipeline_builder.build(self.task,
                                                           self.cfg,
                                                           seed=seed)
            logger.info("AdaptorDataset #%d built datapipeline with seed=%d" %
                        (item, seed))

        training_data = self.datapipeline[item]

        return training_data
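
videoanalyst/data/builder.py then hands this dataset to a regular PyTorch DataLoader; conceptually the wiring looks like this (the arguments are illustrative, the real builder's details differ):

    from torch.utils.data import DataLoader

    # task, cfg and ext_seed come from the experiment configuration
    dataset = AdaptorDataset(task, cfg, seed=ext_seed)
    dataloader = DataLoader(dataset,
                            batch_size=32,  # illustrative
                            shuffle=True,
                            num_workers=4)  # workers run in subprocesses, hence
                                            # the lazy per-worker seeding above
    for training_data in dataloader:
        pass  # batched im_z / im_x / cls_gt / ctr_gt / box_gt tensors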
