研究背景
检测任务
项目代码
Swin-Transformer-Object-Detection code
学习参考(Swin-Transformer源码(已跑通)
环境配置
可在已有mmDetection link 环境基础上进行配置
name: py37pt15
channels:
- pytorch
- psi4
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- blas=1.0=mkl
- ca-certificates=2021.4.13=h06a4308_1
- certifi=2020.12.5=py37h06a4308_0
- cloog=0.18.0=0
- cudatoolkit=10.1.243=h6bb024c_0
- cudnn=7.6.5=cuda10.1_0
- cython=0.29.23=py37h2531618_0
- freetype=2.10.4=h5ab3b9f_0
- gcc-5=5.2.0=1
- gmp=6.2.1=h2531618_2
- intel-openmp=2020.2=254
- isl=0.12.2=0
- jpeg=9b=h024ee3a_2
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.33.1=h53a641e_7
- libffi=3.3=he6710b0_2
- libgcc=7.2.0=h69d50b8_2
- libgcc-ng=9.1.0=hdf63c60_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h2733197_1
- lz4-c=1.9.3=h2531618_0
- mkl=2020.2=256
- mkl-service=2.3.0=py37he8ac12f_0
- mkl_fft=1.3.0=py37h54f3939_0
- mkl_random=1.1.1=py37h0573a6f_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- ncurses=6.2=he6710b0_1
- ninja=1.10.2=hff7bd54_1
- numpy=1.19.2=py37h54aff64_0
- numpy-base=1.19.2=py37hfa32c7d_0
- olefile=0.46=py37_0
- openssl=1.1.1k=h27cfd23_0
- pillow=8.2.0=py37he98fc37_0
- pip=21.0.1=py37h06a4308_0
- python=3.7.10=hdb3f193_0
- pytorch=1.5.0=py3.7_cuda10.1.243_cudnn7.6.3_0
- readline=8.1=h27cfd23_0
- setuptools=52.0.0=py37h06a4308_0
- six=1.15.0=py37h06a4308_0
- sqlite=3.35.4=hdfb4753_0
- tk=8.6.10=hbc83047_0
- torchvision=0.6.0=py37_cu101
- wheel=0.36.2=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- addict==2.4.0
- cycler==0.10.0
- future==0.18.2
- kiwisolver==1.3.1
- matplotlib==3.4.1
- mmcv-full==1.3.1
- mmpycocotools==12.0.3
- opencv-python==4.5.1.48
- pyparsing==2.4.7
- python-dateutil==2.8.1
- pyyaml==5.4.1
- terminaltables==3.1.0
- timm==0.4.5
- yapf==0.31.0
prefix: /home/intern2/anaconda3/envs/py37pt15
其中apex可选。
训练测试过程
与mmdetection基本一致
训练命令
python tools/train.py configs_rib/swin/cascade_mask_rcnn_swin_tiny_rib.py --gpu-ids=7 --cfg-options model.pretrained=./checkpoints/cascade_mask_rcnn_swin_tiny_patch4_window7.pth --work-dir ./work_dirs/cascade_mask_rcnn_swin_rib0425_0506
python tools/train.py configs_rib/swin/cascade_mask_rcnn_swin_tiny_rib.py --gpu-ids=0 --cfg-options model.pretrained=./checkpoints/swin_tiny_patch4_window7_224.pth --work-dir=./work_dirs/cascade_mask_rcnn_swin_rib0425_0506
从链接 https://github.com/SwinTransformer/Swin-Transformer-Object-Detection下载的pretrained model 会有问题,建议从链接 https://github.com/microsoft/Swin-Transformer 下载 swin_tiny_patch4_window7_224.pth 预训练模型。
问题梳理
- 训练启动后关于 backbone registry 的 KeyError的问题
问题描述
KeyError: "CascadeRCNN: 'SwinTransformer is not in the backbone registry'"
解决方式:
在当前工程项目文件夹下运行如下命令
python setup.py develop
参考 issue 9
- 训练启动后关于 relative_position_bias_table 的 KeyError的问题
问题描述:
KeyError: "CascadeRCNN: 'backbone.layers.0.blocks.0.attn.relative_position_bias_table'"
用的预训练模型是在COCO det上微调过的模型,而不是ImageNet预训练中的模型。
解决方式;
从链接 https://github.com/microsoft/Swin-Transformer 下载相对应的模型。
参考 issue 4
- 训练启动后关于初始化的 RuntimeError 的问题
问题描述:
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
问题原因是非分布式训练使用了分布式训练的设置
解决方式:
方案一是改为默认的分布式训练
# multi-gpu training
tools/dist_train.sh --cfg-options model.pretrained= [model.backbone.use_checkpoint=True] [other optional arguments]
方案二是修改 tools/train.py 中代码,加入如下内容:
import torch.distributed as dist
dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1)
方案三是修改配置文件 Swin-Transformer-Object-Detection/configs_rib/swin/cascade_mask_rcnn_swin_tiny.py 代码,将
norm_cfg=dict(type='SyncBN', requires_grad=True),
改为
norm_cfg=dict(type='BN', requires_grad=True),
即'SyncBN'改为'BN'。
'SyncBN'是采用distributed的训练方法,在单GPU non-distributed训练中使用会出现上述错误,改为type='BN' 即可。
- 训练时Apex报错,因而选择禁用
默认情况下,Swin使用apex进行混合精度训练,如果要禁用Apex,请修改Runner的类型为'EpochBasedRunner'并在配置文件中cascade_mask_rcnn_swin_tiny.py的修改并且注释以下代码块:
runner = dict(type='EpochBasedRunner', max_epochs=36)
## Disable apex
# # runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
# # do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
- 将默认的maskRCNN改为无mask的常规目标检测
需要将含mask的配置文件代码注释
首先是configs/swin/cascade_mask_rcnn_swin_tiny_rib.py
dict(type='LoadAnnotations', with_bbox=True), # remove mask
# dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), # remove mask
# dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
然后是configs/base/models/cascade_mask_rcnn_swin_fpn_rib.py
# mask_roi_extractor=dict(
# type='SingleRoIExtractor',
# roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
# out_channels=256,
# featmap_strides=[4, 8, 16, 32]),
# mask_head=dict(
# type='FCNMaskHead',
# num_convs=4,
# in_channels=256,
# conv_out_channels=256,
# num_classes=1,
# loss_mask=dict(
# type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))
# mask_size=28,
# mask_thr_binary=0.5
最后是修改训练数据集,将coco改为voc格式。
_base_ = [
'../_base_/models/cascade_mask_rcnn_swin_fpn_rib.py',
'../_base_/datasets/voc0712.py',
'../_base_/schedules/schedule_1x_rib.py', '../_base_/default_runtime.py'
]
参考 issue 25
- 使用自定义数据集进行训练
与mmdetection修改方式类似。
首先是修改配置文件configs/swin/cascade_mask_rcnn_swin_tiny.py 的类别个数
# num_classes=80,
num_classes=1,
然后是修改mmdet/core/evaluation/class_names.py的类别名
def voc_classes():
return [
'frac',
]
# return [
# 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
# 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
# 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
# ]
最后是mmdet/datasets/voc.py里的类别元组。
CLASSES = ('frac', )
# CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
# 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
# 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
# 'tvmonitor')
到此为止,环境的搭建、数据的准备、配置文件的修改基本准备完成,可以进行自定义数据集的训练过程。