阅读此文默认你已经可以成功运行SSD-TensorFlow的官方源码,也就是能够在VOC2007数据集上成功训练。下面将介绍一下如何用SSD-Tensorflow 训练自己的数据。如果不清楚的话可以参考链接:
https://github.com/balancap/SSD-Tensorflow
https://blog.csdn.net/ei1990/article/details/75282855
https://blog.csdn.net/liuyan20062010/article/details/78905517
明确2件事:
1)、训练自己的数据,我们选用的是什么模型?
2)、用什么模型来fine-tune?
1、我们在设置训练参数的时候,设置的model_name=ssd_300_vgg,就是我们要用ssd_300_vgg这个模型来训练我们的数据,当然我们的数据的类别跟模型默认的类别=21可能不一致,这就需要在训练的时候设置一下参数,num_classes=4,(如果搞清楚前面数据转换步骤这一点应该很容易理解),同理dataset_name要设置成自己的数据名。
2、我们这里选用vgg_16模型进行训练,意思就是利用vgg_16模型里的一些参数初始化我们的ssd_300_vgg模型里的部分参数,只训练其它层的参数,这就需要将指定哪些层不需要初始化,需要训练哪些层。
--checkpoint_exclude_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
--trainable_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
这里直接复制这两行就可以了。
最后附上我的训练文件,vgg_16.ckpt官网下载路径
#!/bin/bash
DATASET_DIR=./tfrecords ###训练集转化成tfrecords存储的路径
TRAIN_DIR=./logs/ ###存储训练结果的路径,包括checkpoint和event,自行指定
CHECKPOINT_PATH=./checkpoints/vgg_16.ckpt ###下载vgg_16模型
python train_ssd_network.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=xxxxxx \ ###具体指定,改为your_data_name, 如果你在前面搞清楚了如何转换自己的数据的话
--dataset_split_name=train \
--model_name=ssd_300_vgg \
--num_classes=4
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_model_scope=vgg_16 \ ####改为vgg_16
--checkpoint_exclude_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
--trainable_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
--save_summaries_secs=60 \
--save_interval_secs=600 \
--weight_decay=0.0005 \
--optimizer=adam \
--learning_rate=0.001 \
--learning_rate_decay_factor=0.94 \
--batch_size=32
这里只是简单的介绍一下,如介绍的不清楚还请见谅,有问题可以联系我。