[ZoeDepth] Code reading and pitfalls

Training

The actual training loop starts at zoedepth/trainers/base_trainer.py:181

Each batch is processed by zoedepth.trainers.zoedepth_trainer.Trainer.train_on_batch, which the loop calls as:

losses = self.train_on_batch(batch, i)
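
To make the division of work concrete, here is a minimal sketch of that structure: a base class owns the epoch/batch loop and calls train_on_batch, which a subclass overrides. It is a simplified stand-in, not the actual ZoeDepth code. ToyTrainer, the history list and the list-based loader are hypothetical and only serve as illustration.

```python
class BaseTrainer:
    """Simplified stand-in for a base trainer that owns the epoch/batch loop."""

    def __init__(self, train_loader, epochs):
        self.train_loader = train_loader
        self.epochs = epochs
        self.history = []  # one dict of losses per processed batch

    def train_on_batch(self, batch, train_step):
        # Subclasses implement the forward pass, the loss and the optimizer step.
        raise NotImplementedError

    def train(self):
        for epoch in range(self.epochs):
            for i, batch in enumerate(self.train_loader):
                losses = self.train_on_batch(batch, i)
                self.history.append(losses)
        return self.history


class ToyTrainer(BaseTrainer):
    """Hypothetical subclass: the 'loss' is simply the mean of the batch."""

    def train_on_batch(self, batch, train_step):
        return {"loss": sum(batch) / len(batch)}


# A plain list of lists stands in for the DataLoader here.
history = ToyTrainer(train_loader=[[1.0, 3.0], [2.0, 4.0]], epochs=2).train()
```

With two batches and two epochs, train_on_batch is called four times, once per batch per epoch.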

Error

Execution fails when it reaches this point:

Training ZoeDepth
Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2023.1\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_comm.py", line 304, in _on_run
    r = r.decode('utf-8')
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe6 in position 1023: unexpected end of data
Traceback (most recent call last):
  File "D:\code\Github_code\ZoeDepth\train_mono_tensorboard_load_pre_handle_bit.py", line 278, in <module>
    main_worker(config.gpu, ngpus_per_node, config, writer)
  File "D:\code\Github_code\ZoeDepth\train_mono_tensorboard_load_pre_handle_bit.py", line 196, in main_worker
    trainer.train()
  File "D:\code\Github_code\ZoeDepth\zoedepth\trainers\base_trainer.py", line 190, in train
    pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train",
  File "C:\conda\envs\zoe\lib\site-packages\torch\utils\data\dataloader.py", line 430, in __iter__
    self._iterator = self._get_iterator()
  File "C:\conda\envs\zoe\lib\site-packages\torch\utils\data\dataloader.py", line 381, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "C:\conda\envs\zoe\lib\site-packages\torch\utils\data\dataloader.py", line 1034, in __init__
    w.start()
  File "C:\conda\envs\zoe\lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
  File "C:\conda\envs\zoe\lib\multiprocessing\context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\conda\envs\zoe\lib\multiprocessing\context.py", line 327, in _Popen
    return Popen(process_obj)
  File "C:\conda\envs\zoe\lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\conda\envs\zoe\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
KeyboardInterrupt

Process finished with exit code -1073741510 (0xC000013A: interrupted by Ctrl+C)
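
Reading the log: the first traceback comes from the PyCharm debugger (pydevd_comm.py) failing to decode a message, not from ZoeDepth itself. The second one passes through popen_spawn_win32.py and ends inside ForkingPickler.dump. On Windows, DataLoader workers are started with spawn, so the worker process object, including the dataset it references, is pickled and sent to each child. The KeyboardInterrupt and exit code 0xC000013A mean the run was interrupted during that step, not that pickling failed. One possible reason the step takes long is a dataset that keeps a lot of preloaded data in memory. That is only an assumption about this setup. The sketch below measures how large an object is once pickled and how long pickling takes. probe_pickle and the toy dictionaries are hypothetical names used for illustration.

```python
import pickle
import time


def probe_pickle(obj):
    """Return (size_in_bytes, seconds) needed to pickle obj once."""
    start = time.perf_counter()
    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    elapsed = time.perf_counter() - start
    return len(payload), elapsed


# Hypothetical stand-ins for a dataset's state: a small index of file names
# versus a large table of samples preloaded into memory.
lazy_state = {"files": [f"img_{i}.png" for i in range(10)]}
preloaded_state = {"samples": [float(i) for i in range(100_000)]}

size_lazy, time_lazy = probe_pickle(lazy_state)
size_preloaded, time_preloaded = probe_pickle(preloaded_state)
```

If the real dataset turns out to be expensive to pickle, a common workaround while debugging is to create the DataLoader with num_workers=0, so that data is loaded in the main process and no worker processes are spawned.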

Inference

    @torch.no_grad()
    def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]:
        """
        Inference interface for the model for PIL image
        Args:
            pil_img (PIL.Image.Image): input PIL image
            pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
            with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
            output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy".
        """
        x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device)
        out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs)
        if output_type == "numpy":
            return out_tensor.squeeze().cpu().numpy()
        elif output_type == "pil":
            # uint16 is required for depth pil image
            out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16)
            return Image.fromarray(out_16bit_numpy)
        elif output_type == "tensor":
            return out_tensor.squeeze().cpu()
        else:
            raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'")
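
The 'pil' branch above stores depth as a 16-bit image by multiplying by 256 and casting to uint16. A minimal sketch of what that implies for a single value follows, assuming the model output is metric depth in metres. It mirrors the expression on plain Python scalars instead of NumPy arrays. encode_depth and decode_depth are hypothetical helper names. Raising on out-of-range values is a choice made for this sketch, since the original expression does not check the range.

```python
DEPTH_SCALE = 256    # same factor as in the 'pil' branch
UINT16_MAX = 65535   # largest value a uint16 pixel can hold


def encode_depth(depth_m):
    """Metres -> 16-bit integer, truncating the fraction like an integer cast."""
    value = int(depth_m * DEPTH_SCALE)
    if not 0 <= value <= UINT16_MAX:
        raise ValueError(
            f"depth {depth_m} m does not fit in uint16 at scale {DEPTH_SCALE}"
        )
    return value


def decode_depth(value):
    """16-bit integer -> metres; the inverse scaling needed when reading the image."""
    return value / DEPTH_SCALE


encoded = encode_depth(2.5)
restored = decode_depth(encoded)
max_depth_m = decode_depth(UINT16_MAX)  # largest depth representable at this scale
```

At this scale the step size is 1/256 m and the largest representable depth is just under 256 m, so a saved depth image has to be divided by 256 to get metres back.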
    
