The datapipeline first appears in videoanalyst/data/adaptor_dataset.py, where the dataset class is created; it inherits from torch.utils.data.Dataset.
self.datapipeline = datapipeline_builder.build(self.task,
self.cfg,
seed=seed)
videoanalyst/data/datapipeline/builder.py defines the overall assembly of the datapipeline. It builds the sampler, the transformers, and the target in turn:
sampler = build_sampler(task, cfg.sampler, seed=seed)
transformers = build_transformer(task, cfg.transformer, seed=seed)
target = build_target(task, cfg.target)
pipeline = []
pipeline.extend(transformers)
pipeline.append(target)
cfg = cfg.datapipeline
name = cfg.name
module = MODULES[name](sampler, pipeline)
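The module returned by MODULES[name](sampler, pipeline) simply chains the sampler with the transformer/target stages. A minimal sketch of what such a datapipeline class has to do (inferred from the builder code above, not the verbatim implementation):

class DatapipelineSketch:
    def __init__(self, sampler, pipeline):
        self.sampler = sampler    # yields raw sampled_data dicts
        self.pipeline = pipeline  # transformers + target, applied in order

    def __getitem__(self, item: int) -> dict:
        data = self.sampler[item]  # dict(data1=..., data2=..., is_negative_pair=...)
        for stage in self.pipeline:
            data = stage(data)     # each stage maps dict -> dict
        return data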
videoanalyst/data/sampler/builder.py defines the overall framework of the sampler, which is itself assembled from submodules. The submodules split into datasets and a filter.
submodules_cfg = cfg.submodules
dataset_cfg = submodules_cfg.dataset
datasets = dataset_builder.build(task, dataset_cfg)
if submodules_cfg.filter.name != "":
filter_cfg = submodules_cfg.filter
data_filter = filter_builder.build(task, filter_cfg)
else:
data_filter = None
name = cfg.name
module = MODULES[name](datasets, seed=seed, data_filter=data_filter)
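As a rough illustration of the config fields this builder reads, a sampler config fragment could look like the following (the module names are assumptions for illustration, not necessarily the registered ones):

# Hypothetical sampler config fragment mirroring the fields read above.
sampler_cfg = {
    "name": "TrackPairSampler",                 # selects MODULES[name]
    "submodules": {
        "dataset": {"name": "LaSOTDataset"},    # forwarded to dataset_builder.build
        "filter": {"name": "TrackPairFilter"},  # "" would yield data_filter = None
    },
}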
videoanalyst/data/sampler/sampler_impl/track_pair_sampler.py implements the sampler itself; it returns a pair of samples:
sampled_data = dict(
data1=data1,
data2=data2,
is_negative_pair=is_negative_pair,
)
As an example, videoanalyst/data/dataset/dataset_impl/lasot.py defines what a dataset returns.
from videoanalyst.evaluation.got_benchmark.datasets import LaSOT
def update_params(self):
r"""
an interface for update params
"""
dataset_root = osp.realpath(self._hyper_params["dataset_root"])
subset = self._hyper_params["subset"]
check_integrity = self._hyper_params["check_integrity"]
self._state["dataset"] = LaSOT(dataset_root,
subset=subset,
check_integrity=check_integrity)
def __getitem__(self, item: int) -> Dict:
img_files, anno = self._state["dataset"][item]
anno = xywh2xyxy(anno)
sequence_data = dict(image=img_files, anno=anno)
return sequence_data
def __len__(self):
return len(self._state["dataset"])
In videoanalyst/evaluation/got_benchmark/datasets/lasot.py, __getitem__ returns:
return img_files, anno
which is exactly what gets wrapped into sequence_data above.
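Note the xywh2xyxy call in the dataset's __getitem__: got_benchmark annotations are in (x, y, w, h) format, while the rest of the pipeline works with corner coordinates. A sketch of that conversion under the common inclusive-corner convention (the repo's own helper may differ in the +/-1 detail):

import numpy as np

def xywh2xyxy(box):
    # (x, y, w, h) -> (x1, y1, x2, y2), e.g. [5, 5, 10, 20] -> [5, 5, 14, 24]
    box = np.asarray(box, dtype=np.float32)
    box[..., 2] = box[..., 0] + box[..., 2] - 1  # x2 = x + w - 1
    box[..., 3] = box[..., 1] + box[..., 3] - 1  # y2 = y + h - 1
    return box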
videoanalyst/data/filter/filter_impl/track_pair_filter.py decides whether a sampled item should be rejected: it returns True (reject) when data is None, and likewise True when the target is too small, too large, or has an extreme aspect ratio; otherwise it returns False (keep). Returning True for None is what keeps the resampling loop in the sampler's __getitem__ running until both samples pass.
def __call__(self, data: Dict) -> bool:
if data is None:
return True
im, anno = data["image"], data["anno"]
if self._hyper_params["target_type"] == "bbox":
bbox = xyxy2xywh(anno)
elif self._hyper_params["target_type"] == "mask":
bbox = cv2.boundingRect(anno)
else:
logger.error("unspported target type {} in filter".format(
self._hyper_params["target_type"]))
exit()
filter_flag = filter_unreasonable_training_boxes(
im, bbox, self._hyper_params)
return filter_flag
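filter_unreasonable_training_boxes is where the size and aspect-ratio thresholds live. A minimal sketch of the kind of checks it performs (the threshold keys and default values here are assumptions, not necessarily the repo's actual hyper-params):

def filter_unreasonable_training_boxes_sketch(im, bbox_xywh, config) -> bool:
    # Return True when the box should be REJECTED.
    x, y, w, h = bbox_xywh
    im_area = im.shape[0] * im.shape[1]
    area_rate = (w * h) / max(im_area, 1)
    ratio = max(w / max(h, 1e-6), h / max(w, 1e-6))
    too_small = area_rate < config.get("min_area_rate", 1e-3)  # assumed key
    too_large = area_rate > config.get("max_area_rate", 0.6)   # assumed key
    too_thin = ratio > config.get("max_ratio", 10.0)           # assumed key
    return too_small or too_large or too_thin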
Now let's step back and look at what the sampler actually does.
Sample procedure:
__getitem__
│
├── _sample_track_pair  # returns a pair of images plus annotations, wrapped in dicts
│   ├── _sample_dataset  # randomly pick a dataset
│   ├── _sample_sequence_from_dataset  # randomly pick a sequence
│   ├── _sample_track_frame_from_static_image  # for image datasets, e.g. COCO
│   └── _sample_track_pair_from_sequence  # for video datasets, e.g. LaSOT
│       └── _sample_pair_idx_pair_within_max_diff  # pick a frame pair within the max gap
│
└── _sample_track_frame
    ├── _sample_dataset
    ├── _sample_sequence_from_dataset
    ├── _sample_track_frame_from_static_image (x2)
    └── _sample_track_frame_from_sequence
The implementation:
def __getitem__(self, item) -> dict:
is_negative_pair = (self._state["rng"].rand() <
self._hyper_params["negative_pair_ratio"])
data1 = data2 = None
sample_try_num = 0
while self.data_filter(data1) or self.data_filter(data2):
if is_negative_pair:
data1 = self._sample_track_frame()
data2 = self._sample_track_frame()
else:
data1, data2 = self._sample_track_pair()
data1["image"] = load_image(data1["image"])
data2["image"] = load_image(data2["image"])
sample_try_num += 1
sampled_data = dict(
data1=data1,
data2=data2,
is_negative_pair=is_negative_pair,
)
return sampled_data
def _sample_track_pair(self) -> Tuple[Dict, Dict]:
dataset_idx, dataset = self._sample_dataset()
sequence_data = self._sample_sequence_from_dataset(dataset)
len_seq = self._get_len_seq(sequence_data)
if len_seq == 1 and not isinstance(sequence_data["anno"][0], list):
# static image dataset
data1 = self._sample_track_frame_from_static_image(sequence_data)
data2 = deepcopy(data1)
else:
# video dataset
data1, data2 = self._sample_track_pair_from_sequence(
sequence_data, self._state["max_diffs"][dataset_idx])
return data1, data2
def _sample_track_frame(self) -> Dict:
_, dataset = self._sample_dataset()
sequence_data = self._sample_sequence_from_dataset(dataset)
len_seq = self._get_len_seq(sequence_data)
if len_seq == 1:
# static image dataset
data_frame = self._sample_track_frame_from_static_image(
sequence_data)
else:
# video dataset
data_frame = self._sample_track_frame_from_sequence(sequence_data)
return data_frame
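The helper _sample_pair_idx_pair_within_max_diff picks two frame indices at most max_diff apart. A hedged sketch of that idea (illustrative, not the verbatim implementation):

import numpy as np

def sample_pair_idx_within_max_diff(rng: np.random.RandomState,
                                    len_seq: int, max_diff: int):
    # Pick a first frame, then a second frame within +/- max_diff of it,
    # clipped to the sequence bounds.
    idx1 = rng.randint(len_seq)
    lo = max(0, idx1 - max_diff)
    hi = min(len_seq, idx1 + max_diff + 1)
    idx2 = rng.randint(lo, hi)
    return idx1, idx2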
Either way, what comes back is sampled_data = dict(data1=data1, data2=data2, is_negative_pair=is_negative_pair).
videoanalyst/data/transformer/transformer_impl/random_crop_transformer.py performs the random crop on the raw images; the result is written back into sampled_data:
def __call__(self, sampled_data: Dict) -> Dict:
r"""
sampled_data: Dict()
input data
Dict(data1=Dict(image, anno), data2=Dict(image, anno))
"""
data1 = sampled_data["data1"]
data2 = sampled_data["data2"]
im_temp, bbox_temp = data1["image"], data1["anno"]
im_curr, bbox_curr = data2["image"], data2["anno"]
im_z, bbox_z, im_x, bbox_x, _, _ = crop_track_pair(
im_temp,
bbox_temp,
im_curr,
bbox_curr,
config=self._hyper_params,
rng=self._state["rng"])
sampled_data["data1"] = dict(image=im_z, anno=bbox_z)
sampled_data["data2"] = dict(image=im_x, anno=bbox_x)
return sampled_data
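crop_track_pair is where the template patch (z, e.g. 127x127) and the search patch (x, e.g. 303x303 in SiamFC++) are produced. A toy sketch of the underlying center-crop-with-context idea (sizes, the context formula, and the function name here are assumptions; the real implementation also applies random shift/scale and pads out-of-image regions):

import cv2
import numpy as np

def center_crop_sketch(im, bbox_xyxy, out_size, context=0.5):
    # Crop a square window around the box center with context padding,
    # then resize to out_size (boundary padding omitted for brevity).
    x1, y1, x2, y2 = bbox_xyxy
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = x2 - x1, y2 - y1
    pad = context * (w + h)
    side = int(np.sqrt((w + pad) * (h + pad)))  # SiamFC-style context window
    xa, ya = int(cx - side / 2), int(cy - side / 2)
    patch = im[max(ya, 0):ya + side, max(xa, 0):xa + side]
    return cv2.resize(patch, (out_size, out_size))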
videoanalyst/data/target/target_impl/densebox_target.py takes the transformer output and generates three kinds of labels: classification, centerness, and box regression:
def __call__(self, sampled_data: Dict) -> Dict:
data_z = sampled_data["data1"]
im_z, bbox_z = data_z["image"], data_z["anno"]
data_x = sampled_data["data2"]
im_x, bbox_x = data_x["image"], data_x["anno"]
is_negative_pair = sampled_data["is_negative_pair"]
# input tensor
im_z = im_z.transpose(2, 0, 1)
im_x = im_x.transpose(2, 0, 1)
# training target
cls_label, ctr_label, box_label = make_densebox_target(
bbox_x.reshape(1, 4), self._hyper_params)
if is_negative_pair:
cls_label[cls_label == 0] = -1
cls_label[cls_label == 1] = 0
training_data = dict(
im_z=im_z,
im_x=im_x,
bbox_z=bbox_z,
bbox_x=bbox_x,
cls_gt=cls_label,
ctr_gt=ctr_label,
box_gt=box_label,
is_negative_pair=int(is_negative_pair),
)
#training_data = super().__call__(training_data)
return training_data
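One plausible reading of the negative-pair remap above: original background locations (0) become ignore (-1) and original foreground locations (1) become background (0), so a negative pair contributes no positive classification targets and only penalizes the former-target locations. A tiny numeric illustration:

import numpy as np

cls_label = np.array([0, 0, 1, 1, -1])  # toy label map: 0=bg, 1=fg, -1=ignore
cls_label[cls_label == 0] = -1          # bg -> ignore
cls_label[cls_label == 1] = 0           # fg -> bg (no positives remain)
print(cls_label)                        # [-1 -1  0  0 -1]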
Looking back at the datapipeline as a whole: the sampler picks an image pair from the datasets, the transformer crops it according to z_size/x_size, and the target generates the training labels.
Finally, back to where it all starts: videoanalyst/data/builder.py defines its own dataset class, and videoanalyst/data/adaptor_dataset.py provides the implementation, which is where the datapipeline comes in. Its return value:
def __getitem__(self, item):
if self.datapipeline is None:
# build datapipeline with random seed the first time when __getitem__ is called
# usually, dataset is already spawned (into subprocess) at this point.
seed = (torch.initial_seed() + item * self._SEED_STEP +
self.ext_seed * self._EXT_SEED_STEP) % self._SEED_DIVIDER
self.datapipeline = datapipeline_builder.build(self.task,
self.cfg,
seed=seed)
logger.info("AdaptorDataset #%d built datapipeline with seed=%d" %
(item, seed))
training_data = self.datapipeline[item]
return training_data
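For context, AdaptorDataset is ultimately consumed by a regular torch.utils.data.DataLoader with multiple workers, which is why the datapipeline is built lazily on the first __getitem__ call: each spawned worker derives its own seed from torch.initial_seed(). A hedged usage sketch (constructor arguments are assumptions; the real wiring lives in videoanalyst/data/builder.py):

from torch.utils.data import DataLoader

dataset = AdaptorDataset(task, cfg)  # hypothetical instantiation
loader = DataLoader(dataset,
                    batch_size=32,
                    num_workers=8,    # each worker lazily builds its own datapipeline
                    pin_memory=True)
for training_data in loader:
    # dict with im_z, im_x, bbox_z, bbox_x, cls_gt, ctr_gt, box_gt, ...
    break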