如何使用TensorFlow在自己的图像数据上训练深度学习模型

1微调原理

在自己的数据集上训练一个新的深度学习模型时,一般采取在预训练ImageNet上进行微调的方法。什么是微调?这里以VGG16为例进行讲解。

如何使用TensorFlow在自己的图像数据上训练深度学习模型_第1张图片

图1.1VGG16结构示意图

如图1.1所示,VGG16的结构为卷积层+全连接层。卷积层分为五个部分,共13层,即图中的conv1~conv5。还有全连接层fc6、fc7、fc8。卷积层加上全连接层总共16层,因此被称为VGG16。如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8这一层,因为fc8这一层输入的是fc7这一层的特征,输出是1000类的概率,这1000类正好对应的是ImageNet模型中的1000个类别,在自己的数据集中一般不是1000类,因此fc8是不适用的,必须将fc8去掉,重新采用适合数据集类别的全连接层,作为新的fc8全连接层,比如数据集为6类,那么新的fc8的输出应当是6类。

此外,在训练的时候,网络参数的初始值并不是随机生成的,而是采用VGG16在ImageNet上已经训练好的参数作为训练的初始值。这样做的原因在于,在ImageNet数据集上训练过的VGG16中的参数已经包含了大量有用的卷积过滤器,与其从零开始初始化VGG16的所有参数,不如使用自己已经训练好的参数作为训练的起点。这样做不仅节约了大量的训练时间,而且还有助于分类器性能的提高。

载入VGG16的参数后,就可以开始训练了,此时需要指定训练层数的范围,一般可以选择以下几种范围进行训练:

a. 只训练fc8。训练范围一定要包含fc8这一层,前面讲过,fc8结构被调整过,因此它的参数不能直接从ImageNet预训练模型中取得,可以只训练fc8层,保持其它层的参数不变,这样就相当于将VGG16当做一个“特征提取器”:用fc7层提取的特征做一个softmax模型分类。这样做的好处是训练速度快,但往往性能不会太好。

b. 训练所有参数。还可以对网络中所有的参数进行训练,这种方法的训练速度可能会比较慢,但是能提取较高的性能,可以充分发挥深度模型的威力。 

c训练部分参数。通常是固定浅层参数不变,训练浅层参数。如训练conv1、conv2的部分参数不变,只训练conv3、conv4、conv5、fc6、fc7、fc8。

这种训练方法就是所谓的对神经网络模型做微调,借助微调可以从预训练模型出发,将神经网络应用到自己的数据集上。

2数据准备

首先要做一些数据准备方面的工作:一是把数据集切分为训练集和验证集,二是转换为tfrecord格式。在data_prepare/文件夹中提供了会用到的数据集和代码。

首先要将自己的数据集切分为训练集和验证集。验证集用于验证模型的准确率,本博客用了一个实验的卫星图片分类数据集,这个数据集一共有六个类。点击打开数据集链接。

在data_prepare文件夹下,使用预先编译好的脚本data_convert.py,将图片转换为tfrecord格式。

python data_convert.py -t pic/ \
  --train-shards 2 \
  --validation-shards 2 \
  --num-threads 2 \
  --dataset-name satellite

这样在pic文件夹下就会生成4个tfrecord文件和1个label.txt文件。

如何使用TensorFlow在自己的图像数据上训练深度学习模型_第2张图片

3定义新的dataset文件

首先,在dataset/目录下新建一个文件夹satellite.py,并将flowers.py文件夹中的内容复制到satellite.py中,接下来需要修改以下几处内容。

第一处修改,

如何使用TensorFlow在自己的图像数据上训练深度学习模型_第3张图片

对应修改如下图所示

如何使用TensorFlow在自己的图像数据上训练深度学习模型_第4张图片

第二处修改修改为image/format部分

'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),

修改完satellite.py后,还需要在同目录的dataset_factory.py文件夹中注册satellite数据库。红色框内为新增加的satellite数据

如何使用TensorFlow在自己的图像数据上训练深度学习模型_第5张图片

4准备训练文件夹

在slim文件夹下新建一个satellite目录,在这个目录下完成以下工作:

 新建一个data 目录,并将第2中准备好的5个转换好格式的训练数据复制进去。

新建一个train_dir目录,用来保存训练过程中的日志和模型。

新建一个pretrained目录,在http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 下载并解压后,会得到一个inception_v3.ckpt 文件,将该文件复制到pretrained 目录下。

5开始训练

在slim文件夹下运行以下命令开始训练

python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=2 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004
如何使用TensorFlow在自己的图像数据上训练深度学习模型_第6张图片

6验证模型准确率      

可以用eval_image_classifier.py程序进行验证,在slim文件夹下运行以下程序

python eval_image_classifier.py \
  --checkpoint_path=satellite/train_dir \
  --eval_dir=satellite/eval_dir \
  --dataset_name=satellite \
  --dataset_split_name=validation \
  --dataset_dir=satellite/data \
  --model_name=inception_v3

执行后,应该会出现类似下面的结果

如何使用TensorFlow在自己的图像数据上训练深度学习模型_第7张图片

7导出模型,并对单张图片进行识别

在slim文件夹下运行以下程序

python export_inference_graph.py \
  --alsologtostderr \
  --model_name=inception_v3 \
  --output_file=satellite/inception_v3_inf_graph.pb \
  --dataset_name satellite

这个命令会在satellite文件夹下生成一个inception_v3_inf_graph.pb文件。(注:inception_v3_inf_graph.pb文件夹只保存了inception_v3的网络结构并不包含训练得到的模型,需要将checkpoint中的模型参数保存进来。需将12106改成train_dir中保存的实际的模型训练步数)在chapter_3文件夹下运行以下命令

python freeze_graph.py \
  --input_graph slim/satellite/inception_v3_inf_graph.pb \
  --input_checkpoint slim/satellite/train_dir/model.ckpt-12106 \
  --input_binary true \
  --output_node_names InceptionV3/Predictions/Reshape_1 \
  --output_graph slim/satellite/frozen_graph.pb

运行导出模型分类单张图片

python classify_image_inception_v3.py \
  --model_path slim/satellite/frozen_graph.pb \
  --label_path data_prepare/pic/label.txt \
  --image_file test_image.jpg

如何使用TensorFlow在自己的图像数据上训练深度学习模型_第8张图片


本博客是我在学习何之源的"21个项目玩转深度学习”这本书时,跟着教材第三章所做的一些实例,实验数据均为第三章的数据,读者朋友如发现错误,有疑问,请留言。谢谢~

你可能感兴趣的:(深度学习)