SimpleBaseline Tensorflow2 代码

环境需求:

Python3.6

Cuda 11.2 及其 Cudnn

tensorflow-gpu==2.5( > 2.0 都可,只是当中会存在一个地方需要进行修改)

opencv

pycocotools

imutils

完整项目:​​​​​​百度网盘

介绍

整体介绍

SimpleBaseline Tensorflow2 代码_第1张图片

   整体项目的结构如图所示

SimpleBaseline_COCO

        -Datasets                           数据集处理相关代码

                -coco_TopDown.py           用于产生train和val的txt文件

                -DataAugmentation.py      数据增强

                -datasets.py                      数据生成器,用于加载数据进行训练

        -log                                     log文件,训练的时候会记录一些指标存在这里

        -Metrics                              计算精度相关代码

                -compute_ap.py                 计算AP

        -Model                                模型文件

                -function_net.py                  存放网络模型的总文件

                -net_all.py                           调用并返回网络模型

        -Other                                 其他文件

                -utils.py                                一些处理和其他功能函数

        -Result                                结果文件

                -best_save.hdf5                   这是我训练了20周期后的结果,精度45左右

                -resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5  imagenet上的预训练权重

                -result.json                            用于评估精度时产生的结果文件

        -config.yaml                       初始配置

        -requirements.txt               环境需求 

        -test.py                              测试GPU能够正常使用

        -train.py                             执行训练

数据集准备

当前只训练COCO数据集

请确保文件是以下的结构

"""
COCO/
    -annotations
        -person_keypoints_train2017.json
        -person_keypoints_val2017.json
    -train2017
        -
        ...
        -
    -val2017
        -
        ...
        -
"""

SimpleBaseline Tensorflow2 代码_第2张图片

代码配置及执行

1. config.yaml进行配置

SimpleBaseline Tensorflow2 代码_第3张图片

 在config.yaml中,需要做一些修改,

coco_path 指向 coco 数据集的根目录

weight_save_path 指向 项目中的best_save.hdf5文件路径

predict_save_json 指向 项目中的Result文件,注意后面加 ' \ '

pre_train_path 指向 imagenet预训练权重目录

2.  代码执行

步骤一:        运行 Datasets 目录下 coco_TopDown.py 文件,运行后会在主目录下产生train.txt 和 val.txt 文件

步骤二(可选):       运行 Metrics 目录下的 compute_ap.py (需在最下方修改config.yaml的路径, 以及 val 路径SimpleBaseline Tensorflow2 代码_第4张图片)文件,即可输出精度,修改config.yaml 中的 show 变量,即可更改是否展示图片效果。

步骤三:                      运行 train.py 文件, 即可开始训练 ,并在每一个epoch后测量精度

你可能感兴趣的:(Tensorflow2,关键点检测,tensorflow,python,深度学习)