在自己的数据集上训练一个新的深度学习模型时,一般采取在预训练ImageNet上进行微调的方法。什么是微调?这里以VGG16为例进行讲解。
图1.1VGG16结构示意图
如图1.1所示,VGG16的结构为卷积层+全连接层。卷积层分为五个部分,共13层,即图中的conv1~conv5。还有全连接层fc6、fc7、fc8。卷积层加上全连接层总共16层,因此被称为VGG16。如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8这一层,因为fc8这一层输入的是fc7这一层的特征,输出是1000类的概率,这1000类正好对应的是ImageNet模型中的1000个类别,在自己的数据集中一般不是1000类,因此fc8是不适用的,必须将fc8去掉,重新采用适合数据集类别的全连接层,作为新的fc8全连接层,比如数据集为6类,那么新的fc8的输出应当是6类。
此外,在训练的时候,网络参数的初始值并不是随机生成的,而是采用VGG16在ImageNet上已经训练好的参数作为训练的初始值。这样做的原因在于,在ImageNet数据集上训练过的VGG16中的参数已经包含了大量有用的卷积过滤器,与其从零开始初始化VGG16的所有参数,不如使用自己已经训练好的参数作为训练的起点。这样做不仅节约了大量的训练时间,而且还有助于分类器性能的提高。
载入VGG16的参数后,就可以开始训练了,此时需要指定训练层数的范围,一般可以选择以下几种范围进行训练:
a. 只训练fc8。训练范围一定要包含fc8这一层,前面讲过,fc8结构被调整过,因此它的参数不能直接从ImageNet预训练模型中取得,可以只训练fc8层,保持其它层的参数不变,这样就相当于将VGG16当做一个“特征提取器”:用fc7层提取的特征做一个softmax模型分类。这样做的好处是训练速度快,但往往性能不会太好。
b. 训练所有参数。还可以对网络中所有的参数进行训练,这种方法的训练速度可能会比较慢,但是能提取较高的性能,可以充分发挥深度模型的威力。
c. 训练部分参数。通常是固定浅层参数不变,训练浅层参数。如训练conv1、conv2的部分参数不变,只训练conv3、conv4、conv5、fc6、fc7、fc8。
这种训练方法就是所谓的对神经网络模型做微调,借助微调可以从预训练模型出发,将神经网络应用到自己的数据集上。
2数据准备
首先要做一些数据准备方面的工作:一是把数据集切分为训练集和验证集,二是转换为tfrecord格式。在data_prepare/文件夹中提供了会用到的数据集和代码。
首先要将自己的数据集切分为训练集和验证集。验证集用于验证模型的准确率,本博客用了一个实验的卫星图片分类数据集,这个数据集一共有六个类。点击打开数据集链接。
在data_prepare文件夹下,使用预先编译好的脚本data_convert.py,将图片转换为tfrecord格式。
python data_convert.py -t pic/ \
--train-shards 2 \
--validation-shards 2 \
--num-threads 2 \
--dataset-name satellite
这样在pic文件夹下就会生成4个tfrecord文件和1个label.txt文件。
3定义新的dataset文件
首先,在dataset/目录下新建一个文件夹satellite.py,并将flowers.py文件夹中的内容复制到satellite.py中,接下来需要修改以下几处内容。
第一处修改,
对应修改如下图所示
第二处修改修改为image/format部分
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
修改完satellite.py后,还需要在同目录的dataset_factory.py文件夹中注册satellite数据库。红色框内为新增加的satellite数据
4准备训练文件夹
在slim文件夹下新建一个satellite目录,在这个目录下完成以下工作:
a 新建一个data 目录,并将第2中准备好的5个转换好格式的训练数据复制进去。
b 新建一个train_dir目录,用来保存训练过程中的日志和模型。
c 新建一个pretrained目录,在http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 下载并解压后,会得到一个inception_v3.ckpt 文件,将该文件复制到pretrained 目录下。
5开始训练
在slim文件夹下运行以下命令开始训练
python train_image_classifier.py \
--train_dir=satellite/train_dir \
--dataset_name=satellite \
--dataset_split_name=train \
--dataset_dir=satellite/data \
--model_name=inception_v3 \
--checkpoint_path=satellite/pretrained/inception_v3.ckpt \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--max_number_of_steps=100000 \
--batch_size=32 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed \
--save_interval_secs=300 \
--save_summaries_secs=2 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004
6验证模型准确率
可以用eval_image_classifier.py程序进行验证,在slim文件夹下运行以下程序
python eval_image_classifier.py \
--checkpoint_path=satellite/train_dir \
--eval_dir=satellite/eval_dir \
--dataset_name=satellite \
--dataset_split_name=validation \
--dataset_dir=satellite/data \
--model_name=inception_v3
执行后,应该会出现类似下面的结果
7导出模型,并对单张图片进行识别
在slim文件夹下运行以下程序
python export_inference_graph.py \
--alsologtostderr \
--model_name=inception_v3 \
--output_file=satellite/inception_v3_inf_graph.pb \
--dataset_name satellite
这个命令会在satellite文件夹下生成一个inception_v3_inf_graph.pb文件。(注:inception_v3_inf_graph.pb文件夹只保存了inception_v3的网络结构并不包含训练得到的模型,需要将checkpoint中的模型参数保存进来。需将12106改成train_dir中保存的实际的模型训练步数)在chapter_3文件夹下运行以下命令
python freeze_graph.py \
--input_graph slim/satellite/inception_v3_inf_graph.pb \
--input_checkpoint slim/satellite/train_dir/model.ckpt-12106 \
--input_binary true \
--output_node_names InceptionV3/Predictions/Reshape_1 \
--output_graph slim/satellite/frozen_graph.pb
运行导出模型分类单张图片
python classify_image_inception_v3.py \
--model_path slim/satellite/frozen_graph.pb \
--label_path data_prepare/pic/label.txt \
--image_file test_image.jpg
本博客是我在学习何之源的"21个项目玩转深度学习”这本书时,跟着教材第三章所做的一些实例,实验数据均为第三章的数据,读者朋友如发现错误,有疑问,请留言。谢谢~