Deeplab v3 (1): 源码训练和测试

本文主要介绍根据github tensorflow/models中官方代码来训练deeplab v3+

源代码: https://github.com/tensorflow/models/tree/master/research/deeplab

配置deeplab v3

  1. Clone 源代码, https://github.com/tensorflow/models.git
  2. 根据官方文档进行安装,https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/installation.md
    这里有几个需要注意的地方:
    (1) cuda 9.0 & tensorflow 1.6以上版本
    (2) 需要将models/research/slim路径导入到PYTHONPATH环境变量中,这个是因为deeplab中的一些工具比如multigrid使用的是slim中实现的
    export PYTHONPATH=$PYTHONPATH:/path-to/models/research/slim
    (3) 然后进行测试就可以得到一个结果,测试需要在models/research/下进行,这样很不方便,想要在models/research/deeplab下进行,可以在models/research/deeplab/model_test.py中导入deeplab模块就可以了,一个例子如下:
import sys
sys.path.append('/path-to/models/research')

测试成功,则表示基本配置成功,可以进行训练的配置了.

训练

参考官方文档: https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/cityscapes.md
以Cityscapes为例进行训练

  1. 将Cityscapes数据转为tfrecord,使用models/research/deeplab/datasets下的脚本: convert_cityscapes.sh、build_cityscapes_data.py、build_data.py脚本,这个需要改一下convert_cityscapes.sh中的一些路径,基本没有什么坑。这三个文件比较简单,可以读一下,之后做数据集基本就基于这三个文件了.
  2. 使用脚本进行训练
CHECKPOINT_PATH='/path-to/models/research/deeplab/initial-checkpoint/xception/model.ckpt'
TRAIN_DIR_PATH='/path-tomodels/research/deeplab/train_dir'
CITYSCAPES_PATH='/path-to/cityscapes/tfrecord'

python train.py \
    --logtostderr \
    --training_number_of_steps=90000 \
    --train_split="train" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --train_crop_size=769 \
    --train_crop_size=769 \
    --train_batch_size=1 \
    --dataset="cityscapes" \
    --tf_initial_checkpoint="${CHECKPOINT_PATH}" \
    --train_logdir="${TRAIN_DIR_PATH}" \
    --dataset_dir="${CITYSCAPES_PATH}"
  1. 评估/测试
CHECKPOINT_PATH='/path-to/models/research/deeplab/initial-checkpoint/deeplabv3_cityscapes_train'
EVAL_DIR_PATH='/path-to/models/research/deeplab/eval_dir'
CITYSCAPES_PATH='/path-to/cityscapes/tfrecord'

python eval.py \
    --logtostderr \
    --eval_split="val" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --eval_crop_size=1025 \
    --eval_crop_size=2049 \
    --dataset="cityscapes" \
    --checkpoint_dir="${CHECKPOINT_PATH}" \
    --eval_logdir="${EVAL_DIR_PATH}" \
    --dataset_dir="${CITYSCAPES_PATH}"

注意,如果使用官方提供的checkpoint,压缩包中是没有checkpoint文件的,需要手动添加一个checkpoint文件
4. 性能
根据官方提供的checkpoint
(1) official-deeplabv3+, tensorflow。eval OS: 16, scale: [1.0]
miou: 0.787332237
(2) official-deeplabv3+, tensorflow。eval OS: 16, scale: [0.75:0.25:1.75]
miou: 0.806650937

注意

由于是第一次跑tf,免不了有很多的坑
1. tf默认占用所有GPU的所有计算资源,通常可能如果只想使用其中一个,或者即可,可以在需要执行的脚本前面加上: CUDA_VISIBLE_DEVICES=gpu_id,即可
2. 如果使用官方提供的checkpoint,压缩包中是没有checkpoint文件的,需要手动添加一个checkpoint文件

你可能感兴趣的:(计算机视觉代码,图像语义分割代码)