P2BNet网络训练VisDrone数据集

P2BNet网络训练VisDrone数据集

文章目录

  • P2BNet网络训练VisDrone数据集
    • 0 P2BNet介绍
    • 1 VisDrone数据集转COCO格式
    • 2 P2BNet训练VisDrone数据集
        • Ⅰ 选定训练模型
        • Ⅱ 下载模型权重,修改全连接层神经元个数
        • Ⅲ 更改类别参数
        • Ⅳ 生成点注释json文件
        • Ⅴ 训练

0 P2BNet介绍

该网络源自ECCV2022最新的目标检测论文Point-to-Box Network for Accurate Object Detection via Single Point Supervision

该论文介绍了一种单点监督的改进方法,作者认为使用单点监督的OTSP方法之所以性能一直都不太好,是因为他们生成建议包的质量普遍太差,所以作者提供了一种轻量级的网络——P2BNet,来弥补过往方式的缺陷,从而生成高质量的实例级包,节约标准成本,性能SOTA!

论文阅读可参考博客:http://t.csdn.cn/kAa8H

代码地址:GitHub - ucas-vg/P2BNet: ECCV2022, Point-to-Box Network for Accurate Object Detection via Single Point Supervision

1 VisDrone数据集转COCO格式

VisDrone是一个无人机的目标检测数据集。由各种无人机摄像头捕获,包含不同场景、天气、光照条件下的各种位置、环境、物体等。

VisDrone转COCO格式可参考博客:http://t.csdn.cn/xYEN7

注意将visdrone数据集的目录格式调整为coco的目录格式,并将转化好的json文件放入annotations文件夹中。

2 P2BNet训练VisDrone数据集

Ⅰ 选定训练模型

​ 这里采用的是faster_rcnn_r50_fpn_1x_coco.py模型,位于P2BNet-main/TOV_mmdetection/configs/faster_rcnn/目录下

Ⅱ 下载模型权重,修改全连接层神经元个数

​ 权重下载地址:mmdetection/configs/faster_rcnn at master · open-mmlab/mmdetection · GitHub

​ 修改全连接层神经元个数代码:

import torch
pretrained_weights  = torch.load('checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth')
#更改权重载入地址
num_class = 10
pretrained_weights['state_dict']['roi_head.bbox_head.fc_cls.weight'].resize_(num_class+1, 1024)
pretrained_weights['state_dict']['roi_head.bbox_head.fc_cls.bias'].resize_(num_class+1)
pretrained_weights['state_dict']['roi_head.bbox_head.fc_reg.weight'].resize_(num_class*4, 1024)
pretrained_weights['state_dict']['roi_head.bbox_head.fc_reg.bias'].resize_(num_class*4)
torch.save(pretrained_weights, "faster_rcnn_r50_fpn_1x_%d.pth"%num_class)

​ 该代码参考博客:http://t.csdn.cn/xYEN7

Ⅲ 更改类别参数

visdrone数据集的数据类别有:ignored regions, pedestrian,people, bicycle, car, van, truck, tricycle, awning-tricycle, bus, motor,others共12种(一般会去掉头和尾,剩下10种类别),需要修改的地方有:

①P2BNet-main/TOV_mmdetection/mmdet/datasets/coco.py 中的Classes:

P2BNet网络训练VisDrone数据集_第1张图片

②P2BNet-main/TOV_mmdetection/mmdet/core/evaluation/class_names.py中的coco_classes:

P2BNet网络训练VisDrone数据集_第2张图片

③P2BNet-main/TOV_mmdetection/configs2/COCO/P2BNet/中的num_classes(80->12):

P2BNet网络训练VisDrone数据集_第3张图片

注意,这里的类别数如果与之前的类别集合数目不一致的话,可能在训练过程中(而不是训练一开始)会出现RuntimeError: CUDA error: device-side assert triggered的错误。

Ⅳ 生成点注释json文件

生成点注释的python文件位于:P2BNet-main/TOV_mmdetection/huicv/corner_dataset/corner_dataset_util.py

注意这个python文件传入的参数,要在执行该文件时,填入相应的路径参数,其中:

ann_file 是visdrone转化好的annotation.json文件,save_ann_file是保存点注释文件的路径,这里的路径应与annotation.json在同一目录下。

注意:json文件里面没存放ignore字段,所以要在同目录下的coner_utils.py文件中将此块代码注释掉,如下:

P2BNet网络训练VisDrone数据集_第4张图片

Ⅴ 训练

训练命令参考:P2BNet/README.md at main · ucas-vg/P2BNet · GitHub

/github.com/ucas-vg/P2BNet/blob/main/TOV_mmdetection/README.md)

你可能感兴趣的:(目标检测)