DETR训练自己的数据集

  • github地址:https://github.com/facebookresearch/detr
1.创建conda环境

推荐通过conda创建虚拟环境,具体操作可见linux系统下创建anaconda新环境及问题解决

2.clone代码并安装依赖库
git clone https://github.com/facebookresearch/detr.git

conda install -c pytorch pytorch torchvision
conda install cython scipy
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
pip install git+https://github.com/cocodataset/panopticapi.git
3.准备自己的数据集
  • 使用COCO格式数据集,其文件目录如下:
    在这里插入图片描述
  • 其中,annotations包含训练集和验证集对应的json文件
    在这里插入图片描述
  • train2017包含训练集图片;val2017包含验证集图片
    ※如果事先准备好了VOC格式的数据集,则可通过脚本进行转换,详见VOC格式数据集转为COCO格式数据集脚本
4.下载预训练模型并修改类别参数
  • 创建一个python文件,根据自己的目标类别数目对原始用于coco数据集的预训练模型进行转换
import torch

pretrained_weights = torch.load("./detr-r50-e632da11.pth")

num_class = 2 + 1
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1,256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)

torch.save(pretrained_weights,'detr_r50_%d.pth'%num_class)
  • 更改detr.py中的目标类别数目(这里干脆都改成一样的了)
    DETR训练自己的数据集_第1张图片
5.训练
  • 运行main.py并传递相应的参数进行训练
python main.py --dataset_file "coco" --coco_path /path/to/coco/ --resume="detr_r50_3.pth"
6.推理
  • 同样运行main.py 需要--eval及其它相关参数
7.plot
  • 借助plot_utils.py,在文件末尾添加下方代码,更改路径并运行即可。
if __name__ == '__main__':
    files = list(Path('../outputs/eval').glob('*.pth'))
    plot_precision_recall(files)
    plt.show()
    plot_logs(logs=Path('../outputs/log/'),fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt')
    plt.show()
  • 参考链接:
    https://blog.csdn.net/w1520039381/article/details/118905718

你可能感兴趣的:(深度学习,pytorch,目标检测,计算机视觉,深度学习)