Below we will fine-tune a pretrained Inception V3 model so that it learns to classify our own data: satellite images of the earth's surface, distinguishing forest, water, rock, farmland, glacier, and urban areas. The dataset has been uploaded to CSDN: https://download.csdn.net/download/viafcccy/11791071
Part 1: The Dataset
First, it helps to understand how Python handles command-line argument parsing:
1. Understanding from __future__ import absolute_import: https://blog.csdn.net/viafcccy/article/details/101061413
2. The argparse library: https://blog.csdn.net/viafcccy/article/details/101061661
# coding:utf-8
from __future__ import absolute_import
import argparse
import os
import logging
from src.tfrecord import main

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--tensorflow-data-dir', default='pic/')
    parser.add_argument('--train-shards', default=2, type=int)
    parser.add_argument('--validation-shards', default=2, type=int)
    parser.add_argument('--num-threads', default=2, type=int)
    parser.add_argument('--dataset-name', default='satellite', type=str)
    return parser.parse_args()

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    args.tensorflow_dir = args.tensorflow_data_dir
    args.train_directory = os.path.join(args.tensorflow_dir, 'train')
    args.validation_directory = os.path.join(args.tensorflow_dir, 'validation')
    args.output_directory = args.tensorflow_dir
    args.labels_file = os.path.join(args.tensorflow_dir, 'label.txt')
    # If label.txt is missing, build it from the class subdirectories of train/.
    if not os.path.exists(args.labels_file):
        logging.warning('Can\'t find label.txt. Now create it.')
        all_entries = os.listdir(args.train_directory)
        dirnames = []
        for entry in all_entries:
            if os.path.isdir(os.path.join(args.train_directory, entry)):
                dirnames.append(entry)
        with open(args.labels_file, 'w') as f:
            for dirname in dirnames:
                f.write(dirname + '\n')
    main(args)  # pass args into the main() function of src.tfrecord
Switch to this directory on the command line (with the images arranged as pic/train/<class>/... and pic/validation/<class>/..., which is the layout the script expects) and run:
python data_convert.py -t pic/ --train-shards 2 --validation-shards 2 --num-threads 2 --dataset-name satellite
At this point several errors appear, all caused by incompatibilities between Python 2 and Python 3. The tracebacks point into two files, the standard library's random.py and the book's tfrecord.py, but all of the fixes belong in tfrecord.py.
1. File "E:\Deep-Learning-21-Examples\chapter_3\data_prepare\src\tfrecord.py", line 341, in _find_image_files
random.shuffle(shuffled_index)
File "C:\ProgramData\Anaconda3\lib\random.py", line 275, in shuffle
x[i], x[j] = x[j], x[i]
TypeError: 'range' object does not support item assignment
In tfrecord.py, the index that gets shuffled is created with range(...), so convert it to a real list before shuffling:
shuffled_index = list(range(len(filenames)))
random.shuffle(shuffled_index)
Reason: in Python 3, range() no longer returns a list but a lazy range object, which random.shuffle() cannot rearrange in place. (Writing random.shuffle(list(shuffled_index)) alone would shuffle a throwaway copy and leave the original order untouched, so convert first and shuffle the list.)
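A minimal standalone illustration of the problem and the fix:

import random

indices = range(5)           # Python 3: a lazy range object, not a list
try:
    random.shuffle(indices)  # shuffle assigns items in place, which range forbids
except TypeError as e:
    print(e)                 # 'range' object does not support item assignment

indices = list(range(5))     # materialize a real list first
random.shuffle(indices)      # in-place shuffle now works
print(indices)               # e.g. [3, 0, 4, 1, 2]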
2. File "C:\Users\Administrator\Desktop\DL\21\Deep-Learning-21-Examples-master\chapter_3\data_prepare\src\tfrecord.py", line 273, in _process_image_files
for thread_index in xrange(len(ranges)):
NameError: name 'xrange' is not defined
将tfrecord.py中所有的xrange改为range
原因:在Python 3中,range()与xrange()合并为range( )。
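If you would rather not edit every call site, a common Python 2/3 compatibility shim (an alternative to the book's approach, not part of it) is to alias the name once near the top of tfrecord.py:

try:
    xrange           # defined on Python 2
except NameError:
    xrange = range   # Python 3: range already behaves like the old xrange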
The remaining, similar fixes are:
tfrecord.py line 160: change to with open(filename, 'rb') as f:
tfrecord.py lines 94 and 96: change to colorspace = b'RGB' and image_format = b'JPEG'
tfrecord.py line 104: change to 'image/class/text': _bytes_feature(str.encode(text)),
tfrecord.py line 106: change to 'image/filename': _bytes_feature(os.path.basename(str.encode(filename))),
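All four of these fixes exist for the same reason: Python 3 separates str from bytes, and tf.train.BytesList only accepts bytes. A minimal sketch of the helper involved (the book's _bytes_feature is essentially this):

import tensorflow as tf

def _bytes_feature(value):
    # Wrap a bytes value for a TFRecord Example; Python 3 strings must be
    # encoded to bytes first, hence the str.encode(...) calls above.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

print(_bytes_feature(b'JPEG'))   # works: bytes in, Feature proto out
# _bytes_feature('JPEG')         # TypeError on Python 3: str is not bytes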
Once the script finishes, it produces the dataset files: the training TFRecord shards, the validation TFRecord shards, and the label file label.txt.
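To sanity-check the conversion, you can count the records in each shard with TF 1.x's record iterator; the glob below assumes the shards landed in pic/ and follow the satellite_%s_*.tfrecord naming registered later, so adjust it to your actual output:

import glob
import tensorflow as tf

for path in sorted(glob.glob('pic/satellite_train_*.tfrecord')):
    count = sum(1 for _ in tf.python_io.tf_record_iterator(path))
    print(path, count, 'records')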
Part 2: Fine-tuning the Model with TensorFlow Slim
Download the source code.
As with the cifar10 example, use the git command (or just visit https://github.com/tensorflow/models) to download the models repository, then open the research folder.
For more background, see https://www.cnblogs.com/bmsl/p/dongbin_bmsl_01.html
Also download the official pretrained model from http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
After downloading:
Inside the slim folder, create a satellite directory with three subfolders, matching the paths used by the training command below: train_dir, data, and pretrained.
Put the extracted ckpt model file into pretrained.
Put the five data files generated earlier (the four TFRecord shards plus label.txt) into data.
Then create a new file satellite.py in slim/datasets as the dataset definition:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the satellite dataset (adapted from the flowers dataset
script at tensorflow/models/slim/datasets/download_and_convert_flowers.py).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'satellite_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}

_NUM_CLASSES = 6

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 5',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading satellite images.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)
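Once the data files are in satellite/data, a quick way to confirm this definition works is to pull one decoded example through slim's data provider (a minimal sketch, run from the slim directory so the datasets package is importable):

import tensorflow as tf
from datasets import satellite

slim = tf.contrib.slim

dataset = satellite.get_split('train', 'satellite/data')
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = provider.get(['image', 'label'])

with tf.Session() as sess:
    with slim.queues.QueueRunners(sess):  # start the input pipeline queues
        img, lbl = sess.run([image, label])
        print(img.shape, lbl)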
Next, modify the dataset registry dataset_factory.py to register our own dataset:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A factory-pattern class which returns classification image/label pairs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'satellite': satellite,
}


def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
  """Given a dataset name and a split_name returns a Dataset.

  Args:
    name: String, the name of the dataset.
    split_name: A train/test split name.
    dataset_dir: The directory where the dataset files are stored.
    file_pattern: The file pattern to use for matching the dataset source files.
    reader: The subclass of tf.ReaderBase. If left as `None`, then the default
      reader defined by each dataset is used.

  Returns:
    A `Dataset` class.

  Raises:
    ValueError: If the dataset `name` is unknown.
  """
  if name not in datasets_map:
    raise ValueError('Name of dataset unknown %s' % name)
  return datasets_map[name].get_split(
      split_name,
      dataset_dir,
      file_pattern,
      reader)
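With the registration in place, train_image_classifier.py can look our dataset up by name. You can verify the wiring from the slim directory in a couple of lines:

from datasets import dataset_factory

# 'satellite' now resolves through datasets_map to the satellite.py module.
dataset = dataset_factory.get_dataset('satellite', 'train', 'satellite/data')
print(dataset.num_samples)   # 4800
print(dataset.num_classes)   # 6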
Run the following command from the slim directory (Windows cmd syntax; ^ continues the line):

python train_image_classifier.py ^
  --train_dir=satellite/train_dir ^
  --dataset_name=satellite ^
  --dataset_split_name=train ^
  --dataset_dir=satellite/data ^
  --model_name=inception_v3 ^
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt ^
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits ^
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits ^
  --max_number_of_steps=100000 ^
  --batch_size=32 ^
  --learning_rate=0.001 ^
  --learning_rate_decay_type=fixed ^
  --save_interval_secs=300 ^
  --save_summaries_secs=2 ^
  --log_every_n_steps=10 ^
  --optimizer=rmsprop ^
  --weight_decay=0.00004

Flag by flag:
--train_dir: directory where logs and checkpoints are saved
--dataset_name: name of the registered dataset
--dataset_split_name: which split to train on
--dataset_dir: directory holding the TFRecord data files
--model_name: which model architecture to use
--checkpoint_path: location of the pretrained checkpoint
--checkpoint_exclude_scopes: the two layers NOT restored from the pretrained checkpoint
--trainable_scopes: the only layers whose weights are updated during training
--max_number_of_steps: maximum number of training steps
--batch_size: number of examples per batch
--learning_rate: learning rate
--learning_rate_decay_type=fixed: keep the learning rate constant (no decay schedule)
--save_interval_secs=300: save a checkpoint every 300 seconds
--save_summaries_secs=2: write summaries (the training log) every 2 seconds
--log_every_n_steps=10: print training status every 10 steps
--optimizer: the optimizer to use (RMSProp)
--weight_decay: the L2 regularization hyperparameter
Note that to cut down on computation we train only the final classification layers here (the counterpart of fc8 in VGG terms, i.e. InceptionV3/Logits and InceptionV3/AuxLogits); everything else keeps its pretrained weights. To train all parameters instead, simply remove the --trainable_scopes flag.
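Under the hood the two scope flags do different jobs: checkpoint_exclude_scopes decides which variables are restored from the pretrained checkpoint, while trainable_scopes decides which variables receive gradient updates. A minimal sketch of the restore side using the slim API (an illustration, not the exact code of train_image_classifier.py):

import tensorflow as tf

slim = tf.contrib.slim

# After the InceptionV3 graph has been built, exclude the two heads: their
# weight shapes depend on the number of classes (6 here vs. ImageNet's 1001),
# so they cannot be loaded from the pretrained checkpoint.
variables_to_restore = slim.get_variables_to_restore(
    exclude=['InceptionV3/Logits', 'InceptionV3/AuxLogits'])

# Then a Saver built over this subset restores only the compatible weights:
# saver = tf.train.Saver(variables_to_restore)
# saver.restore(sess, 'satellite/pretrained/inception_v3.ckpt')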