TensorFlow 对象检测 API 教程 4

TensorFlow 对象检测 API 教程 - 第4部分:训练模型

在本教程中,认为已经选择了预先训练的模型,找到了现有的数据集或创建了自己的数据集,并将其转换为 TFRecord文件。现在准备好训练自己模型。


一. 模型配置文件

如果你以前有转移学习的经验,可能会产生一个自从本教程第二部分以来一直徘徊的问题。那个问题是,如何修改被设计为在 COCO 数据集的90个类上工作的预先训练的模型,以处理新数据集的 X 个类?要在 object detection API 之前完成,必须删除网络的最后 90 个神经元分类层,并将其替换为新的图层。下面显示了 TensorFlow 中的一个示例。


# Assume fc_2nd_last is the 2nd_last fully connected layer in your network and nb_classes is the number of classes in your new dataset.
shape = (fc_2nd_last.get_shape().as_list()[-1], nb_classes)
fc_last_W = tf.Variable(tf.truncated_normal(shape, stddev=1e-2))
fc_last_b = tf.Variable(tf.zeros(nb_classes))
logits = tf.nn.xw_plus_b(fc_2nd_last, fc_last_W, fc_last_b)

要使用 object detection API 来实现这一点,只需修改模型配置文件中的一行代码即可。在克隆 TensorFlow models1 的位置,进入到 object_detection/samples/configs 目录。在此文件夹中,可以找到所有预先训练的模型的配置文件。

复制所选模型的配置文件,并将其移动到一个新文件夹,并在其中执行所有训练。在这个新文件夹中,创建一个名为 data 的文件夹并将 TFRecord 文件移动到其中。创建另一个名为 models 的文件夹,并将所选择的预训练模型的 .ckpt (检查点)文件(其中3个)移动到此文件夹中。回想一下,model_detection_zoo.md 包含每个预先训练的模型的下载链接,这里的每个模型的下载将不仅包含 .pb 文件(在教程第1部分的 jupyter notebook 中使用过),还包含 .ckpt 文件。在 models 文件夹内创建另一个名为 train 的文件夹。


二. 修改配置文件

在文本编辑器中打开新移动的配置文件,在最上面将类的数量更改为数据集中的数量。接下来,将 fine_tune_checkpoint 的路径更改为指向 model.ckpt 文件。如果遵循模型结构,建议改为:

fine_tune_checkpoint: "models/model.ckpt"

参数 num_steps 决定在完成之前将要运行多少个训练步骤。这个数字实际上取决于数据集的大小以及其他因素(包括让模型训练的时间)。一旦开始训练,建议先看看每个训练步骤需要多长时间,并相应地调整 num_steps

接下来,需要更改训练数据集和评估数据集的 input_pathlabel_map_pathInput_path 只是到自己的 TFRecord 文件。在可以设置 label_map_path 的路径之前,需要创建它应该指向的文件。它所要查找的是一个 .pbtxt 文件,其中包含数据集每个标签的 ID名称。可以按照以下格式在任何文本文件中创建此文件。


item {
  id: 1
  name: 'Green'
} 
item {
  id: 2
  name: 'Red'
}

确保从 id:1 开始,而不是 0。 建议把这个文件放在自己的数据文件夹中。最后将 num_examples 设置为拥有的评估样本的数量。


三. 训练

进入 object_detection 文件夹并将 train.py 复制到新创建的培训文件夹中。要开始训练,只需将终端窗口导航到此文件夹(确保已按照教程第1部分中的安装说明操作),然后在命令行中输入


python train.py --logtostderr --train_dir=./models/train --pipeline_config_path=rfcn_resnet101_coco.config

pipline_config_path 指向配置文件。现在开始培训。当心,根据你的系统,培训可能需要几分钟的时间才能开始,所以如果它没有崩溃或停止,请给它更多的时间。

如果计算机内存不足会导致训练的失败,可以尝试多种解决方案。首先尝试添加参数

batch_queue_capacity: 2
prefetch_queue_capacity: 2

train_config 部分的配置文件。例如,将两行放在 gradient_clipping_by_normfine_tune_checkpoint 之间。上面的数字 2 应该只是开始训练的开始值。这些值的默认值分别是 810 ,增加这些值应该有助于加速训练。

就是这样,现在已经开始训练,这将能够调整模型!如果想更好地了解训练的进展情况,可以考虑使用TensorBoard 。

在接下来的文章将讲述说明如何保存所训练的模型,并在项目部署了!

你可能感兴趣的:(TensorFlow 对象检测 API 教程 4)