安装retinanet_PyTorch 实现 RetinaNet 目标检测

这篇文章介绍一个 PyTorch 实现的 RetinaNet 实现目标检测。文章的思想来自论文:Focal Loss for Dense Object Detection。这个实现的主要目标是为了方便读者能够很好的理解和更改源代码。

 

结果

当前的实现能达到 33.7% 的 mAP(600px 分辨率,Resnet-50)。论文里的结果是 34.0% mAP,造成这个差别的主要原因可能是这里使用了 Adam 优化器,而论文里使用了 SGD 和 weight decay。

安装

1. 用 Git 克隆 https://github.com/yhenon/pytorch-retinanet

2. 安装必备包:

apt-get install tk-dev python-tk

1apt-getinstalltk-devpython-tk

3. 安装 Python 包:

pip install cffi

pip install pandas

pip install pycocotools

pip install cython

pip install pycocotools

pip install opencv-python

pip install requests

1

2

3

4

5

6

7pipinstallcffi

pipinstallpandas

pipinstallpycocotools

pipinstallcython

pipinstallpycocotools

pipinstallopencv-python

pipinstallrequests

4. 编译 NMS 扩展.

cd pytorch-retinanet/lib

bash make.sh

cd ../

1

2

3cdpytorch-retinanet/lib

bashmake.sh

cd../

怎样训练

训练主要用 train.py 文件。现在可用的训练数据有两个: COCO 和 CSV。

要训练 COCO:

python train.py --dataset coco --coco_path ../coco --depth 50

1pythontrain.py--datasetcoco--coco_path../coco--depth50

如果要训练自己的数据集,要用 CSV 格式:

python train.py --dataset csv --csv_train --csv_classes --csv_val

1pythontrain.py--datasetcsv--csv_train--csv_classes--csv_val

预训练模型

可以在这里下载预训练模型:链接。

项目地址

项目在 Github 上,点击访问。

本站微信群、QQ群(三群号 726282629):

你可能感兴趣的:(安装retinanet)