siamfc++中的trainer的具体实现

trainer = MODULES[name](optimizer, dataloader, monitors)

在videoanalyst/engine/trainer/trainer_base.py看看参数:

    # Default hyper-parameters shared by all trainers; deep-copied per
    # instance in __init__ so each trainer may mutate its own copy.
    default_hyper_params = dict(
        exp_name="default_training",  # experiment name used for bookkeeping
        exp_save="snapshots",  # directory where snapshots are written
        max_epoch=20,  # total number of training epochs
    )

    def __init__(self, optimizer, dataloader, monitors=None):
        r"""Build a trainer from an optimizer, a dataloader and optional monitors.

        Parameters
        ----------
        optimizer : object
            Project optimizer wrapper; must expose ``_model`` with a ``loss``
            attribute -- TODO confirm exact interface against the optimizer impl.
        dataloader : iterable
            Yields training batches; wrapped with ``iter()`` here so ``next()``
            can be used in the training loop.
        monitors : list, optional
            Monitor objects updated once per iteration (default: empty list).
        """
        # NOTE(review): original signature used a mutable default ``monitors=[]``,
        # which is shared across every instance; replaced with the None idiom.
        if monitors is None:
            monitors = []
        self._hyper_params = deepcopy(
            self.default_hyper_params)  # mapping-like object
        self._state = dict()  # pipeline state
        self._model = optimizer._model
        self._losses = optimizer._model.loss
        self._optimizer = optimizer
        self._monitors = monitors
        self._dataloader = iter(dataloader)  # get the iterable data loader

这里首先来看传入的 monitors 参数。

monitors

videoanalyst/engine/monitor/monitor_impl/text_info.py 在每次 update 时拼接当前 epoch 的学习率、loss、extra、time 等信息,写入 engine_state 供 trainer 的进度条展示。

    def update(self, engine_data: Dict):
        r"""Assemble a one-line progress summary and store it in engine state.

        Reads epoch / schedule / loss / extra / timing information from
        *engine_data* and writes the formatted string to
        ``engine_state["print_str"]`` (printed later via the trainer's pbar).
        """
        engine_state = self._state["engine_state"]

        parts = ['epoch %d, ' % engine_state["epoch"]]
        # schedule values (learning rate etc.)
        for name, value in engine_data["schedule_info"].items():
            parts.append('%s: %.1e, ' % (name, value))
        # per-loss scalars, moved to host memory first
        for name, loss_val in engine_data["training_losses"].items():
            parts.append('%s: %.3f, ' % (name, loss_val.detach().cpu().numpy()))
        # extra per-loss diagnostics
        for extra in engine_data["extras"].values():
            for name, value in extra.items():
                parts.append('%s: %.3f, ' % (name, value))
        # elapsed time per phase
        for name, value in engine_data["time_dict"].items():
            parts.append("%s: %.1e, " % (name, value))
        # peak CUDA memory in MB
        peak_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
        parts.append(" max mem: {:.1f}M".format(peak_mb))

        engine_state["print_str"] = "".join(parts)

同样videoanalyst/engine/monitor/monitor_impl/tensorboard_logger.py定义了update函数

    def update(self, engine_data: Dict):
        r"""Log all scalars in *engine_data* to tensorboard at the current step.

        The summary writer is built lazily on the first call so that the
        starting global step is known when it is created.
        """
        state = self._state["engine_state"]
        epoch, max_epoch = state["epoch"], state["max_epoch"]
        iteration, max_iteration = state["iteration"], state["max_iteration"]
        # flatten (epoch, iteration) into a single monotonically growing step
        global_step = iteration + epoch * max_iteration

        # lazy writer construction on first update
        if self._state["writer"] is None:
            self._build_writer(global_step=global_step)
            logger.info(
                "Tensorboard writer built, starts recording from global_step=%d"
                % global_step, )
            logger.info(
                "epoch=%d, max_epoch=%d, iteration=%d, max_iteration=%d" %
                (epoch, max_epoch, iteration, max_iteration))

        # walk engine_data recursively and emit every scalar found
        self._add_scalar_recursively(self._state["writer"], engine_data, "",
                                     global_step)

trainer

只需要看videoanalyst/engine/trainer/trainer_impl/regular_trainer.py

    def train(self):
        r"""Run a single training epoch.

        Increments the epoch counter, then loops ``num_iterations`` times:
        fetch a batch, forward through the model, compute and sum the losses,
        backpropagate (optionally through an AMP grad scaler), step the
        optimizer, and push per-iteration data to every registered monitor.
        """
        # lazy one-time setup (see init_train); flag prevents repeating it
        if not self._state["initialized"]:
            self.init_train()
        self._state["initialized"] = True

        # epoch counter is pre-incremented, so this call trains the new epoch
        self._state["epoch"] += 1
        epoch = self._state["epoch"]
        num_iterations = self._hyper_params["num_iterations"]

        # update engine_state consumed by monitors (e.g. tensorboard logger)
        self._state["max_epoch"] = self._hyper_params["max_epoch"]
        self._state["max_iteration"] = num_iterations

        # per-epoch gradient modification -- NOTE(review): modify_grad is also
        # called per-iteration below with an extra argument; confirm the
        # double call is intentional in the optimizer implementation
        self._optimizer.modify_grad(epoch)
        pbar = tqdm(range(num_iterations))
        self._state["pbar"] = pbar
        self._state["print_str"] = ""

        # phase name -> elapsed seconds, filled by the Timer context managers
        time_dict = OrderedDict()
        for iteration, _ in enumerate(pbar):
            self._state["iteration"] = iteration
            # measure data-loading time separately from compute
            with Timer(name="data", output_dict=time_dict):
                training_data = next(self._dataloader)
            training_data = move_data_to_device(training_data,
                                                self._state["devices"][0])

            schedule_info = self._optimizer.schedule(epoch, iteration)
            self._optimizer.zero_grad()

            # forward propagation: model prediction, then each named loss
            # returns (scalar_loss, extras_dict); total loss is their sum
            with Timer(name="fwd", output_dict=time_dict):
                predict_data = self._model(training_data)
                training_losses, extras = OrderedDict(), OrderedDict()
                for loss_name, loss in self._losses.items():
                    training_losses[loss_name], extras[loss_name] = loss(
                        predict_data, training_data)
                total_loss = sum(training_losses.values())

            # backward propagation
            with Timer(name="bwd", output_dict=time_dict):
                if self._optimizer.grad_scaler is not None:
                    # AMP path: only scale().backward() is visible here;
                    # presumably scaler.step/update happen inside
                    # self._optimizer.step -- confirm in the optimizer impl
                    self._optimizer.grad_scaler.scale(total_loss).backward()
                else:
                    total_loss.backward()
            self._optimizer.modify_grad(epoch, iteration)
            with Timer(name="optim", output_dict=time_dict):
                self._optimizer.step()

            # package per-iteration info for the monitors
            trainer_data = dict(
                schedule_info=schedule_info,
                training_losses=training_losses,
                extras=extras,
                time_dict=time_dict,
            )

            # monitors may fill self._state["print_str"] (e.g. TextInfo)
            for monitor in self._monitors:
                monitor.update(trainer_data)
            del training_data  # release batch tensors before next iteration
            print_str = self._state["print_str"]
            pbar.set_description(print_str)

已知trainingdata的形式是

        # Batch dict consumed by the task model's train_forward.
        # NOTE(review): shapes/dtypes are not visible here -- presumably image
        # tensors plus FCOS-style training targets; confirm against the
        # datapipeline implementation.
        training_data = dict(
            im_z=im_z,  # template (exemplar) image, read as "im_z" in train_forward
            im_x=im_x,  # search image, read as "im_x" in train_forward
            bbox_z=bbox_z,  # template box
            bbox_x=bbox_x,  # search box
            cls_gt=cls_label,  # classification target
            ctr_gt=ctr_label,  # centerness target
            box_gt=box_label,  # box-regression target
            is_negative_pair=int(is_negative_pair),  # presumably 1 when z/x show different objects -- TODO confirm
        )

这里关注predict_data = self._model(training_data)是重点,在videoanalyst/model/task_model/taskmodel_impl/siamese_track.py定义了forward函数

    def train_forward(self, training_data):
        r"""Training-time forward pass over one (template, search) batch.

        Extracts backbone features from both images, adjusts each branch,
        cross-correlates template kernels against search features, and runs
        the head to produce classification / centerness / box predictions.
        """
        # shared backbone on template (z) and search (x) images
        feat_z = self.basemodel(training_data["im_z"])
        feat_x = self.basemodel(training_data["im_x"])
        # branch-specific feature adjustment (classification / regression)
        kernel_cls = self.c_z_k(feat_z)
        kernel_reg = self.r_z_k(feat_z)
        search_cls = self.c_x(feat_x)
        search_reg = self.r_x(feat_x)
        # depth-wise cross-correlation: template kernel over search features
        corr_reg = xcorr_depthwise(search_reg, kernel_reg)
        corr_cls = xcorr_depthwise(search_cls, kernel_cls)
        # head emits final score maps and boxes
        cls_score, ctr_score, bbox, corr_fea = self.head(corr_cls, corr_reg)
        predict_data = dict(
            cls_pred=cls_score,
            ctr_pred=ctr_score,
            box_pred=bbox,
        )
        # optionally expose the correlation feature map to callers
        if self._hyper_params["corr_fea_output"]:
            predict_data["corr_fea"] = corr_fea
        return predict_data

即根据成对的模板/搜索图片(pair)计算出预测结果 predict_data。

你可能感兴趣的:(siamfc++解析,pytorch)