最近transform在cv圈可谓hot到爆,记录一下之前参加比赛所用的代码swin-transform,算法来源于ICCV2020的best paper。废话少说,直接上源码
code:https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
paper:https://openaccess.thecvf.com/content/ICCV2021/papers/Liu_Swin_Transformer_Hierarchical_Vision_Transformer_Using_Shifted_Windows_ICCV_2021_paper.pdf
搭建swin-transform需要的环境
1.conda create -n swin python=3.7
2.conda activate swin
3.conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=11.0
4.pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
5.pip install mmdet
本教程所用的mmdetection版本:code链接https://github.com/open-mmlab/mmdetection/releases/tag/v2.18.0
1.执行mmdetection下的setup.py编译环境
至此环境的事基本上弄完了
2.由于我们目标只有一类,需要修改mmdetction下相关文件,新建一个data文件存放数据集,数据集按coco数据集格式准备。

找到configs文件下_base_文件、打开_base_文件下dataset文件里的coco_detection.py |
|
|
|
对应code
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]

找到configs文件下_base_文件、打开_base_文件下models如图所示:本文选用mask_rcnn_r50_fpn.py作为backbone。修改mask_rcnn_r50_fpn.py里的num_class |
|
|
|
对应code
model = dict(
type='MaskRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_head=dict(
type='FCNMaskHead',
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=80,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
修改mmdet
找到mmdet下core里的evaluation文件。打开找到class_names.py 修改数据集的label |
|
|
|

def coco_classes():
return [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
]
def coco_classes():
return['你自己的label',]
找到mmdet下datasets下的coco.py,如图 |
|
|
|

class CocoDataset(CustomDataset):
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
开始训练
python tools/train.py configs/swin/mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py

跑起来了,开始训练后,训练日志在mmdetection下work_dir里。 |
|
|
|
调用训练好的模型测试
import asyncio
from argparse import ArgumentParser
from mmdet.apis import (async_inference_detector, inference_detector,
init_detector, show_result_pyplot)
def parse_args():
parser = ArgumentParser()
parser.add_argument('img', default='image root', help='Image file')
parser.add_argument('config',default='config root', help='Config file')
parser.add_argument('checkpoint',default='save model root', help='Checkpoint file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
parser.add_argument(
'--async-test',
action='store_true',
help='whether to set async options for async inference.')
args = parser.parse_args()
return args
def main(args):
model = init_detector(args.config, args.checkpoint, device=args.device)
result = inference_detector(model, args.img)
show_result_pyplot(model, args.img, result, score_thr=args.score_thr)
async def async_main(args):
model = init_detector(args.config, args.checkpoint, device=args.device)
tasks = asyncio.create_task(async_inference_detector(model, args.img))
result = await asyncio.gather(tasks)
show_result_pyplot(model, args.img, result[0], score_thr=args.score_thr)
if __name__ == '__main__':
args = parse_args()
if args.async_test:
asyncio.run(async_main(args))
else:
main(args)

完了