一、原理
在自己的数据集上训练一个新的深度学习模型时,一般采取在预训练ImageNet上进行微调的方法。什么是微调?这里以VGG16为例进行讲解。
VGG16的结构为卷积+全连接层。卷积层分为5个部分共13层,即conv1~conv5。还有三层全连接层,即fc6、fc7、fc8。卷积层加上全连接层合起来一共为16层。如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8这一层。原因是fc8层的输入是fc7层的特征,输出是1000类的概率,这1000类正好对应了ImageNet模型中的1000个类别。在自己的数据中,类别数一般不是1000类,因此fc8层的结构在此时是不适用的。必须将fc8层去掉,重新采用符合数据集类别数的全连接层,作为新的fc8.比如数据集为5类,那么新的fc8的输出也应当是5类。
此外,在训练的时候,网络的参数的初始值并不是随机化生成的,而是采用VGG16在ImageNet上已经训练好的参数作为训练的初始值。这样做的原因在于,在ImageNet数据集上训练过的VGG16的参数已经包含了大量有用的卷积过滤器,与其从零开始初始化VGG16的所有参数,不如使用已经训练好的参数当作训练的起点。这样做不仅可以节约大量训练时间,而且有助于分类起性能的提高。
载入VGG16的参数后,就可以开始训练了。此时需要指定训练层数的范围。一般来说,可以选择以下几种范围进行训练:
二、数据集准备
将jpg格式样本集合转化为tfrecord格式。
三、使用Tensorflow Slim微调模型
slim是google公司公布的一个图像分类工具包,不仅定义了一些方便的接口,还提供了很多ImageNet数据集上常用的网络结构和预训练模型。包括VGG16\VGG19、Inception v1~v4、ResNet 50、ResNet101、MobileNet在内大多数常用模型的结构以及预训练模型,更多的模型会被持续添加进来。
1)下载Tensorflow Slim的源代码
git clone https://github.com/tensorflow/models.git
找到models/research/slim文件夹。
2)定义新的datasets文件
在slim/datasets中,定义了所有可以使用的数据库,为了使用我之前创建的tfrecord数据进行训练,必须要在datasets中定义新的数据库如handGesturePic。
首先在datasets/目录下新建一个文件handGesturePic.py,并将flowers.py文件中的内容复制到handGesturePic.py中。然后修改以下几处内容。
_FILE_PATTERN='handGesturePic_%s_*.tfrecord'//改成自己的图片的命名
SPLITS_TO_SIZES={‘train’:9488,'validation':2000}//训练集和测试集的总数目
_NUM_CLASSES=2
第二处修改:image/format部分
‘image/format’:tf.FixedLenFeature((),tf.string,default_value='jpg').//定义图片的默认格式。
修改完handGesturePic.py之后,还需要在同目录的data_factory.py文件中注册handGesturePic数据库。
添加以下内容:from datasets import handGesturePic
datasets_map={
’ cifarlO ’: cifarlO,
’ flowers ’: flowers,
’ image net ’: imagenet,
’ mnist ’: mnist,
‘handGesturePic’:handGesturePic,}
3)准备训练文件夹
在slim中新建 handGesturePic目录,在这个目录中进行以下操作:
新建一个data目录,将之前生成的5个转换好的训练数据复制进去(4个.tfrecord,1个label.txt)。
新建一个空的train_dir目录,用来保存训练过程中的日志和模型。
新建一个pretrained目录,在slim的GitHubi页面找到Inception-V3模型的下载地址http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz,下载并解压后,得到 inception_v3.ckpt文件,将该文件复制到pretrained目录下。
4)开始训练
在slim文件夹下,运行以下命令就可以开始训练了:
python train_image_calssifier.py –train_dir=handGesturePic/train_dir –dataset_name=handGesturePic –dataset_split_name=train –dataset_dir=handGesturePic/data –model_name=inception_v3 –checkpoint_path=handGesturePic/pretrained/inception_v3.ckpt –checkpoint_exclude_scopes=InceptionV3/Logits, InceptionV3/AuxLogits –trainable_scopes=InceptionV3/Logits, InceptionV3/AuxLogits –max_number_of_steps=10000 –batch_size=32 –learning_rate=0.001 –learning _rate_decay_type=fixed –save_interval_secs=300 –save_summaries_secs=2 –log_every_n_steps=10 –optimizer=rmsprop –weight_decay=0.00004
参数解释:
trainable_scopes=InceptionV3/Logits, InceptionV3/AuxLogits。trainable_scopes规定了在模型中微调变量的范围。这里的设定表示只对 InceptionV3/Logits, InceptionV3/AuxLogits两个变量进行微调,其他变量都保持不动。 InceptionV3/Logits, InceptionV3/AuxLogits是inception V3的末端层。只对最后一层分类层进行训练,比如原来是1000类,现在训练的只是2类。如果不设定trainable_scopes,就只会对模型中所有的参数进行训练。
5)验证模型准确率
执行脚本:python eval_image_classifier.py –checkpoint_path=handGesturePic/train_dir –eval_dir=handGesturePic/eval_dir –dataset_name=handGesturePic –dataset_split_name=validation –dataset_dir=handGesturePic/data –model_name=inception_v3
修改eval_image_classifier.py中’ Accuracy': slim.metrics.streaming_accuracy(predicti。ns, labels),
’ Recall_S ’: slim.metrics.streaming_reca ll_at_k(
logits, labels, 5),//确定输出前几个的准确率,因为我只有2类,所以改为'1'
模型的准确率和召回率均为98%。
6)Tensorboard可视化
命令:tensorboard –logdir handGesturePic/train_dir
可以看到损失变化的曲线。当损失曲线比较平缓,收敛较慢时,可以考虑增大学习率,以加快收敛速度;如果损失曲线波动较大,无法收敛,就可能是学习率过大,此时就可以尝试适当减少学习率。
7)导出模型并对单张图片进行识别
首先在slim文件夹下运行:
python export_inference_graph.py –-alsologtostderr --model_name=inception_v3 –output_file=handGesturePic/inception_v3_inf_graph.pb –dataset_name handGesturePic
这个命令会在handGesturePic文件夹生成一个inception_v3_inf_graph.pb文件。该文件只保存了inception v3的网络结构,并不包含训练得到的模型参数。需要将checkpoint中的模型参数保存进来。
Python freeze_graph.py –input_graph handGesturePic/inception_v3_inf_graph.pb –input_checkpoint handGesturePic/train_dir/model.ckpt-5000 –input_binary true –output_node_names InceptionV3/Predictions/Reshape_1 –output_graph handGesturePic/frozen_graph.pb
如何使用导出的frozen_graph.pb来对单张图片进行预测?编写一个classify_image_inception_v3.py脚本来完成这件事:
Python classify_image_inception_v3.py –model_path handGesturePic/frozen_graph.pb –label_path data_prepare/Pic/label.txt –image_file test_image.jpg