[Code Debugging] "Multi-scale Positive Sample Refinement for Few-shot Object Detection"

Paper: https://arxiv.org/abs/2007.09384
Code: https://github.com/jiaxi-wu/MPSR

My configuration:
Python: 3.6.5 (Ubuntu 20.04)
PyTorch: 1.9.0
CUDA: 11.1
GPU: RTX 3090 Ti (24GB)

1. Installing dependencies

conda install ipython pip

pip install ninja yacs cython matplotlib tqdm opencv-python

pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html

2. Installing pycocotools

Create a folder named install_dir:

cd install_dir

git clone https://github.com/cocodataset/cocoapi.git

cd cocoapi/PythonAPI

python setup.py build_ext install

pycocotools builds successfully.

3. Installing apex

cd install_dir

git clone https://github.com/NVIDIA/apex.git

cd apex

git checkout 96b017a

python setup.py install --cuda_ext --cpp_ext

apex builds successfully.

4. Build

python setup.py build develop

Error:

/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/include/ATen/core/TensorBody.h:303:30: note: declared here
  303 |   DeprecatedTypeProperties & type() const {
      |                              ^~~~
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1673, in _run_ninja_build
    env=env)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/subprocess.py", line 418, in run
    output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "setup.py", line 68, in <module>
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/setuptools/__init__.py", line 153, in setup
    return distutils.core.setup(**attrs)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/core.py", line 148, in setup
    dist.run_commands()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/dist.py", line 955, in run_commands
    self.run_command(cmd)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/dist.py", line 974, in run_command
    cmd_obj.run()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/command/build.py", line 135, in run
    self.run_command(cmd_name)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/dist.py", line 974, in run_command
    cmd_obj.run()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/setuptools/command/build_ext.py", line 79, in run
    _build_ext.run(self)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/Cython/Distutils/old_build_ext.py", line 186, in run
    _build_ext.build_ext.run(self)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/command/build_ext.py", line 339, in run
    self.build_extensions()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 708, in build_extensions
    build_ext.build_extensions(self)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/Cython/Distutils/old_build_ext.py", line 195, in build_extensions
    _build_ext.build_ext.build_extensions(self)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/command/build_ext.py", line 448, in build_extensions
    self._build_extensions_serial()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/command/build_ext.py", line 473, in _build_extensions_serial
    self.build_extension(ext)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/setuptools/command/build_ext.py", line 202, in build_extension
    _build_ext.build_extension(self, ext)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/distutils/command/build_ext.py", line 533, in build_extension
    depends=ext.depends)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 538, in unix_wrap_ninja_compile
    with_cuda=with_cuda)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1359, in _write_ninja_file_and_compile_objects
    error_prefix='Error compiling objects for extension')
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1683, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error compiling objects for extension

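Solution (newer PyTorch releases no longer provide the AT_CHECK macro, so the CUDA sources must use TORCH_CHECK instead):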

https://github.com/facebookresearch/maskrcnn-benchmark/issues/1274
https://github.com/amazon-science/siam-mot/blob/main/readme/INSTALL.md

cuda_dir="maskrcnn_benchmark/csrc/cuda"
perl -i -pe 's/AT_CHECK/TORCH_CHECK/' $cuda_dir/deform_pool_cuda.cu $cuda_dir/deform_conv_cuda.cu
# You can then run the regular setup command
python setup.py build develop
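
For readers without perl, the same substitution can be sketched in Python. This is a minimal sketch, not part of the MPSR repository: the helper name replace_at_check is hypothetical, and the two file paths are the ones used in the perl command above. Unlike the perl one-liner (which has no /g flag and so replaces only the first match on each line), this version replaces every occurrence.

```python
from pathlib import Path


def replace_at_check(paths):
    """Replace every AT_CHECK with TORCH_CHECK in the given files.

    Returns a dict mapping each path to the number of occurrences found
    before the rewrite. (Hypothetical helper, not part of MPSR.)
    """
    counts = {}
    for path in map(Path, paths):
        text = path.read_text()
        counts[str(path)] = text.count("AT_CHECK")
        path.write_text(text.replace("AT_CHECK", "TORCH_CHECK"))
    return counts


if __name__ == "__main__":
    # Same two files as in the perl command; run from the MPSR repository root.
    cuda_dir = Path("maskrcnn_benchmark/csrc/cuda")
    targets = [cuda_dir / "deform_pool_cuda.cu", cuda_dir / "deform_conv_cuda.cu"]
    print(replace_at_check([p for p in targets if p.exists()]))
```

After running it, the regular python setup.py build develop command applies as before.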

The build succeeds.

4. Preparing the datasets

4.1 VOC dataset

mkdir -p datasets/voc
wget https://pjreddie.com/media/files/VOCtrainval_11-May-2012.tar
wget https://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar
wget https://pjreddie.com/media/files/VOCtest_06-Nov-2007.tar
tar xf VOCtrainval_11-May-2012.tar
tar xf VOCtrainval_06-Nov-2007.tar
tar xf VOCtest_06-Nov-2007.tar

Place the VOC2007 and VOC2012 folders under datasets/voc.
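
Before running the dataset scripts, it can help to confirm the folders ended up where expected. The following is a rough sketch under assumptions: the helper name missing_voc_dirs is hypothetical, and it assumes the layout datasets/voc/VOC2007 and datasets/voc/VOC2012, each containing the standard VOC subfolders Annotations, ImageSets and JPEGImages.

```python
from pathlib import Path

# Standard subfolders of a Pascal VOC release (assumed to be what the scripts read).
EXPECTED_SUBDIRS = ("Annotations", "ImageSets", "JPEGImages")


def missing_voc_dirs(root="datasets/voc", years=("VOC2007", "VOC2012")):
    """Return the expected VOC directories that do not exist under root."""
    root = Path(root)
    return [
        str(root / year / sub)
        for year in years
        for sub in EXPECTED_SUBDIRS
        if not (root / year / sub).is_dir()
    ]


if __name__ == "__main__":
    missing = missing_voc_dirs()
    print("layout looks complete" if not missing else f"missing: {missing}")
```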

4.2 Preparing the base and few-shot datasets

bash tools/fewshot_exp/datasets/init_fs_dataset_standard.sh

I did not use the COCO dataset here, so I commented out lines 7 and 8 of tools/fewshot_exp/datasets/init_fs_dataset_standard.sh.

Error:

Cloning into '../Fewshot_Detection'...
remote: Enumerating objects: 365, done.
remote: Counting objects: 100% (7/7), done.
remote: Compressing objects: 100% (7/7), done.
remote: Total 365 (delta 2), reused 1 (delta 0), pack-reused 358
Receiving objects: 100% (365/365), 121.58 KiB | 496.00 KiB/s, done.
Resolving deltas: 100% (267/267), done.
Traceback (most recent call last):
  File "tools/fewshot_exp/datasets/voc_create_base.py", line 1, in <module>
    from maskrcnn_benchmark.data.datasets.voc import PascalVOCDataset
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/data/__init__.py", line 2, in <module>
    from .build import make_data_loader
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/data/build.py", line 8, in <module>
    from maskrcnn_benchmark.utils.imports import import_file
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/imports.py", line 4, in <module>
    if torch._six.PY3:
AttributeError: module 'torch._six' has no attribute 'PY3'
Traceback (most recent call last):
  File "tools/fewshot_exp/datasets/voc_create_standard.py", line 2, in <module>
    from maskrcnn_benchmark.data.datasets.voc import PascalVOCDataset
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/data/__init__.py", line 2, in <module>
    from .build import make_data_loader
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/data/build.py", line 8, in <module>
    from maskrcnn_benchmark.utils.imports import import_file
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/imports.py", line 4, in <module>
    if torch._six.PY3:
AttributeError: module 'torch._six' has no attribute 'PY3'
Traceback (most recent call last):
  File "tools/fewshot_exp/datasets/coco_create_base.py", line 1, in <module>
    from maskrcnn_benchmark.data.datasets.coco import COCODataset
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/data/__init__.py", line 2, in <module>
    from .build import make_data_loader
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/data/build.py", line 8, in <module>
    from maskrcnn_benchmark.utils.imports import import_file
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/imports.py", line 4, in <module>
    if torch._six.PY3:
AttributeError: module 'torch._six' has no attribute 'PY3'
Traceback (most recent call last):
  File "tools/fewshot_exp/datasets/coco_create_standard.py", line 1, in <module>
    from maskrcnn_benchmark.data.datasets.coco import COCODataset
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/data/__init__.py", line 2, in <module>
    from .build import make_data_loader
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/data/build.py", line 8, in <module>
    from maskrcnn_benchmark.utils.imports import import_file
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/imports.py", line 4, in <module>
    if torch._six.PY3:
AttributeError: module 'torch._six' has no attribute 'PY3'

Solution:

https://blog.csdn.net/pangweijian/article/details/120371802

Done.

5. Few-shot training on the VOC dataset

5.1 Base training on the 3 VOC splits

Download the ResNet-101 weights, then in each of the three yaml config files under configs/fewshot/base and under configs/fewshot_baseline/base, change the WEIGHT path (MODEL.WEIGHT) to the location of the downloaded ResNet-101 weight file.

Set the number of GPUs and the GPU IDs in tools/fewshot_exp/train_voc_base.sh to match your own machine.

Base training:

bash tools/fewshot_exp/train_voc_base.sh

1. Error:

Traceback (most recent call last):
  File "./tools/train_net.py", line 197, in <module>
    main()
  File "./tools/train_net.py", line 190, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "./tools/train_net.py", line 65, in train
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/checkpoint.py", line 61, in load
    checkpoint = self._load_file(f)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/checkpoint.py", line 134, in _load_file
    return load_c2_format(self.cfg, f)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/c2_model_loading.py", line 206, in load_c2_format
    return C2_FORMAT_LOADER[cfg.MODEL.BACKBONE.CONV_BODY](cfg, f)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/c2_model_loading.py", line 192, in load_resnet_c2_format
    state_dict = _load_c2_pickled_weights(f)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/utils/c2_model_loading.py", line 135, in _load_c2_pickled_weights
    if torch._six.PY3:
AttributeError: module 'torch._six' has no attribute 'PY3'
Killing subprocess 58533

Solution:
Open the file that raises the error (here maskrcnn_benchmark/utils/c2_model_loading.py) and change PY3 to PY37.
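
An alternative that does not depend on which attributes torch._six exposes is to test the interpreter version with the standard library. This is only a sketch: the function name load_pickled_weights is hypothetical, and it assumes the `if torch._six.PY3:` branch in the traceback exists to choose how a pickle file is decoded, which should be checked against the actual source before editing.

```python
import pickle
import sys


def load_pickled_weights(file_path):
    """Load a pickle file without consulting torch._six.

    Hypothetical stand-in for a loader guarded by `if torch._six.PY3:`.
    """
    with open(file_path, "rb") as f:
        if sys.version_info[0] >= 3:
            # latin1 lets Python 3 decode byte strings pickled by Python 2.
            return pickle.load(f, encoding="latin1")
        return pickle.load(f)
```

The same sys.version_info test can replace the torch._six.PY3 check in maskrcnn_benchmark/utils/imports.py, where the earlier traceback from the dataset scripts was raised.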

2. Error:

2023-04-10 18:29:41,611 maskrcnn_benchmark.trainer INFO: Start training
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/multiprocessing/resource_sharer.py", line 139, in _serve
    signal.pthread_sigmask(signal.SIG_BLOCK, range(1, signal.NSIG))
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/signal.py", line 60, in pthread_sigmask
    sigs_set = _signal.pthread_sigmask(how, mask)
  ValueError: signal number 32 out of range

Solution:
Set _C.DATALOADER.NUM_WORKERS = 0 in maskrcnn_benchmark/config/defaults.py, then rebuild the project (python setup.py build develop) so the installed copy picks up the change.

3. Error:

Traceback (most recent call last):
  File "./tools/train_net.py", line 197, in <module>
    main()
  File "./tools/train_net.py", line 190, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "./tools/train_net.py", line 96, in train
    data_loader_closeup
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/engine/trainer.py", line 78, in do_train
    loss_dict = model(images, targets, closeups, closeup_targets)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/apex-0.1-py3.6-linux-x86_64.egg/apex/amp/_initialize.py", line 197, in new_fwd
    **applier(kwargs, input_caster))
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py", line 70, in forward
    proposals, proposal_losses = self.rpn(images, features, targets, closeup_rpn_features)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/modeling/rpn/rpn.py", line 189, in forward
    return self._forward_train(anchors, objectness, rpn_box_regression, targets, closeup_objectness)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/modeling/rpn/rpn.py", line 208, in _forward_train
    anchors, objectness, rpn_box_regression, targets, closeup_objectness
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/modeling/rpn/loss.py", line 114, in __call__
    sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/maskrcnn_benchmark-0.1-py3.6-linux-x86_64.egg/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py", line 50, in __call__
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
RuntimeError: radix_sort: failed on 1st step: cudaErrorInvalidDevice: invalid device ordinal
Killing subprocess 61318
Traceback (most recent call last):
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/launch.py", line 340, in <module>
    main()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/launch.py", line 326, in main
    sigkill_handler(signal.SIGTERM, None)  # not coming back
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
    raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/test/anaconda3/envs/mpsr/bin/python', '-u', './tools/train_net.py', '--local_rank=0', '--config-file', 'configs/fewshot/base/e2e_voc_split1_base.yaml']' returned non-zero exit status 1.

Solution:
Switch from PyTorch 1.8.0 to 1.9.0, then rebuild apex and the whole MPSR project.

pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

4. Error:

cannot import name 'container_abcs' from 'torch._six'

Solution: find the line that raises the error and change
from torch._six import container_abcs
to:
import collections.abc as container_abcs
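
If the same code has to run on both older and newer PyTorch versions, the import can be wrapped in a fallback instead of being replaced outright. A minimal sketch follows; the is_sequence helper is hypothetical and only demonstrates that the alias behaves the same either way.

```python
try:
    # Available in older PyTorch releases only.
    from torch._six import container_abcs
except ImportError:
    # Standard-library fallback; also used when torch is not installed.
    import collections.abc as container_abcs


def is_sequence(obj):
    """Return True if obj is a sequence (list, tuple, str, ...)."""
    return isinstance(obj, container_abcs.Sequence)


if __name__ == "__main__":
    print(is_sequence([1, 2, 3]), is_sequence(42))
```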

5. Error:

    "__main__", mod_spec)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/launch.py", line 173, in <module>
    main()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/launch.py", line 169, in main
    run(args)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/run.py", line 624, in run
    )(*cmd_args)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/launcher/api.py", line 116, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
    return f(*args, **kwargs)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/launcher/api.py", line 238, in launch_agent
    result = agent.run()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/agent/server/api.py", line 700, in run
    result = self._invoke_run(role)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/agent/server/api.py", line 822, in _invoke_run
    self._initialize_workers(self._worker_group)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/agent/server/api.py", line 670, in _initialize_workers
    self._rendezvous(worker_group)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/agent/server/api.py", line 530, in _rendezvous
    store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py", line 60, in next_rendezvous
    self.timeout,
RuntimeError: Address already in use

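Solution (the default rendezvous port is already taken, for example by another training job; launching with a different --master_port avoids the clash):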
https://github.com/facebookresearch/maskrcnn-benchmark/issues/241
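
One way to pick a port that is not in use is to let the operating system choose it. The helper below is a sketch with a hypothetical name (find_free_port); the number it prints could then be passed to torch.distributed.launch as --master_port in the training script. There is a small window between picking the port and the launcher binding it, so this reduces collisions but does not strictly rule them out.

```python
import socket


def find_free_port(host="127.0.0.1"):
    """Ask the OS for a currently unused TCP port and return its number."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))  # port 0 means "any free port"
        return s.getsockname()[1]


if __name__ == "__main__":
    print(find_free_port())
```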

Training completes and produces three weight files.

5.2 Fine-tuning

As before, set the number of GPUs and the GPU IDs in tools/fewshot_exp/train_voc_standard.sh to match your own machine.
Also as before, modify train_voc_standard.sh following this issue:
https://github.com/facebookresearch/maskrcnn-benchmark/issues/241

Fine-tuning:

bash tools/fewshot_exp/train_voc_standard.sh

1. Error:

Traceback (most recent call last):
  File "tools/train_net.py", line 197, in <module>
    main()
  File "tools/train_net.py", line 190, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "tools/train_net.py", line 65, in train
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
  File "/home/test/code/MPSR/maskrcnn_benchmark/utils/checkpoint.py", line 62, in load
    self._load_model(checkpoint)
  File "/home/test/code/MPSR/maskrcnn_benchmark/utils/checkpoint.py", line 98, in _load_model
    load_state_dict(self.model, checkpoint.pop("model"))
  File "/home/test/code/MPSR/maskrcnn_benchmark/utils/model_serialization.py", line 80, in load_state_dict
    model.load_state_dict(model_state_dict)
  File "/home/test/anaconda3/envs/mpsr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))

        size mismatch for roi_heads.box.predictor.cls_score.weight: copying a param with shape torch.Size([16, 1024]) from checkpoint, the shape in current model is torch.Size([21, 1024]).
        size mismatch for roi_heads.box.predictor.cls_score.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([21]).
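
The sizes in the message (16 in the checkpoint, 21 in the model) would be consistent with a classifier head for 15 base classes plus background being loaded into a model built for all 20 VOC classes plus background. To list every parameter affected, a small comparison helper can be used. This is a diagnostic sketch with a hypothetical name (shape_mismatches); it assumes both arguments map parameter names to objects with a .shape attribute, as PyTorch state dicts do, and that the names match exactly (no "module." prefix on one side).

```python
def shape_mismatches(checkpoint_state, model_state):
    """Return {name: (checkpoint_shape, model_shape)} where shapes differ.

    Parameters present in only one of the two mappings are ignored.
    """
    return {
        name: (tuple(value.shape), tuple(model_state[name].shape))
        for name, value in checkpoint_state.items()
        if name in model_state
        and tuple(value.shape) != tuple(model_state[name].shape)
    }
```

With PyTorch available, it could be called as shape_mismatches(torch.load(path, map_location="cpu")["model"], model.state_dict()), where the "model" key is the one the checkpointer in the traceback pops.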

Solution:
https://github.com/jiaxi-wu/MPSR/issues/9

After making the change, run the following commands to regenerate 'voc0712_split%dbase_pretrained.pth':

python tools/fewshot_exp/trans_voc_pretrained.py 1
python tools/fewshot_exp/trans_voc_pretrained.py 2
python tools/fewshot_exp/trans_voc_pretrained.py 3

Fine-tuning completes.

5.3 Evaluation

Evaluation covers the 1/2/3/5/10-shot settings on the 3 splits.
By default the results are stored in fs_exp/voc_standard_results, and a summary can be obtained quickly with:

python tools/fewshot_exp/cal_novel_voc.py fs_exp/voc_standard_results

6. Conclusion

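The official code uses 2 GPUs, whereas I have only 1. Even after changing many parameters according to the single-GPU training instructions provided by the author, my results were still far below the official ones. I mention them here only as a reference and do not list exact values: base-training mAP was between 0.2 and 0.4, and mAP after fine-tuning was between 0.1 and 0.2.
If you spot any mistakes, corrections are welcome.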
