我的数据集是卫星图片,共6类:
wood/
water/
rock/
wetland/
glacier/
urban/
1,是把数据集切分为训练集和验证集
我的训练集结构:
pic/
train/
wood/
water/
rock/
wetland/
glacier/
urban/
validation/
wood/
water/
rock/
wetland/
glacier/
urban/
将图片分为train和validation两个目录,分别表示训练使用的图片和验证使用的图片。
在每个目录中,分别以类别名为文件夹名保存所有图像。
在每个类别文件夹下,存放的就是原始的图像(如JPG格式的图像文件)。
2,将图像转换为tfrecord格式。
from __future__ import absolute_import
import argparse
import os
import logging
from src.tfrecord import main
def parse_args():
    """Parse the command-line options of the TFRecord conversion script.

    Returns:
        argparse.Namespace with the attributes ``tensorflow_data_dir``,
        ``train_shards``, ``validation_shards``, ``num_threads`` and
        ``dataset_name``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-t', '--tensorflow-data-dir', default='pic/')
    # The three integer options share the same default value.
    for int_flag in ('--train-shards', '--validation-shards', '--num-threads'):
        cli.add_argument(int_flag, default=2, type=int)
    cli.add_argument('--dataset-name', default='satellite', type=str)
    return cli.parse_args()
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    # Every path the converter needs is derived from the single data root.
    args.tensorflow_dir = args.tensorflow_data_dir
    args.train_directory = os.path.join(args.tensorflow_dir, 'train')
    args.validation_directory = os.path.join(args.tensorflow_dir, 'validation')
    args.output_directory = args.tensorflow_dir
    args.labels_file = os.path.join(args.tensorflow_dir, 'label.txt')
    if not os.path.exists(args.labels_file):
        logging.warning("Can't find label.txt. Now create it.")
        # One label per class sub-directory of the training set.  The names
        # are sorted because os.listdir() returns entries in arbitrary order
        # and the line order of label.txt decides each class's integer id.
        dirnames = sorted(
            entry for entry in os.listdir(args.train_directory)
            if os.path.isdir(os.path.join(args.train_directory, entry)))
        with open(args.labels_file, 'w') as f:
            f.writelines(dirname + '\n' for dirname in dirnames)
    main(args)
解析:
main函数所在的文件:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os
import random
import sys
import threading
import numpy as np
import tensorflow as tf
import logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def _int64_feature(value):
    """Wrap an integer, or a list of integers, as an int64 Feature proto."""
    # A lone value is promoted to a one-element list.
    values = value if isinstance(value, list) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single str/bytes value as a bytes Feature proto."""
    # This conversion had to be added to the original book code: the value is
    # normalized to bytes before it is placed in the BytesList.
    encoded = tf.compat.as_bytes(value)
    bytes_list = tf.train.BytesList(value=[encoded])
    return tf.train.Feature(bytes_list=bytes_list)
def _convert_to_example(filename, image_buffer, label, text, height, width):
    """Assemble the Example proto that describes one image.

    Args:
      filename: string, path to an image file, e.g. '/path/to/example.JPG'.
        Only its basename is stored in the record.
      image_buffer: string, JPEG encoding of the RGB image.
      label: integer, identifier for the ground truth for the network.
      text: string, unique human-readable class name, e.g. 'dog'.
      height: integer, image height in pixels.
      width: integer, image width in pixels.

    Returns:
      Example proto.
    """
    # Every record is declared as a 3-channel RGB JPEG.
    feature_map = {
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature('RGB'),
        'image/channels': _int64_feature(3),
        'image/class/label': _int64_feature(label),
        'image/class/text': _bytes_feature(text),
        'image/format': _bytes_feature('JPEG'),
        'image/filename': _bytes_feature(os.path.basename(filename)),
        'image/encoded': _bytes_feature(image_buffer),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)
class ImageCoder(object):
    """Bundles the TensorFlow ops used to transcode and decode images."""

    def __init__(self):
        # One Session is shared by every image coding call.
        self._sess = tf.Session()

        # PNG bytes -> RGB tensor -> JPEG bytes.
        self._png_data = tf.placeholder(dtype=tf.string)
        decoded_png = tf.image.decode_png(self._png_data, channels=3)
        self._png_to_jpeg = tf.image.encode_jpeg(
            decoded_png, format='rgb', quality=100)

        # JPEG bytes -> RGB tensor.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(
            self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        """Re-encode PNG bytes as JPEG bytes."""
        feed = {self._png_data: image_data}
        return self._sess.run(self._png_to_jpeg, feed_dict=feed)

    def decode_jpeg(self, image_data):
        """Decode JPEG bytes and return the image as a 3-channel array."""
        feed = {self._decode_jpeg_data: image_data}
        image = self._sess.run(self._decode_jpeg, feed_dict=feed)
        shape = image.shape
        # The decoded result must be rank 3 with exactly three channels.
        assert len(shape) == 3
        assert shape[2] == 3
        return image
def _is_png(filename):
"""Determine if a file contains a PNG format image.
Args:
filename: string, path of the image file.
Returns:
boolean indicating if the image is a PNG.
"""
return '.png' in filename
def _process_image(filename, coder):
    """Load one image file and return its JPEG bytes and dimensions.

    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
      coder: instance of ImageCoder to provide TensorFlow image coding utils.

    Returns:
      image_buffer: string, JPEG encoding of RGB image.
      height: integer, image height in pixels.
      width: integer, image width in pixels.
    """
    # Binary mode: the raw bytes are handed to the decoder unchanged.
    with open(filename, 'rb') as image_file:
        raw = image_file.read()

    # PNG input is transcoded so that every stored image is a JPEG.
    if _is_png(filename):
        logging.info('Converting PNG to JPEG for %s' % filename)
        raw = coder.png_to_jpeg(raw)

    # Decoding is only needed to validate the image and read its size.
    decoded = coder.decode_jpeg(raw)
    assert len(decoded.shape) == 3
    height, width = decoded.shape[0], decoded.shape[1]
    assert decoded.shape[2] == 3
    return raw, height, width
def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
                               texts, labels, num_shards, command_args):
    """Processes and saves list of images as TFRecord in 1 thread.

    Args:
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
      thread_index: integer, unique batch to run index is within
        [0, len(ranges)).
      ranges: list of pairs of integers specifying ranges of each batches to
        analyze in parallel.
      name: string, unique identifier specifying the data set.
      filenames: list of strings; each string is a path to an image file.
      texts: list of strings; each string is human readable, e.g. 'dog'.
      labels: list of integer; each integer identifies the ground truth.
      num_shards: integer number of shards for this data set.
      command_args: object providing ``dataset_name`` (output file prefix)
        and ``output_directory`` (where the shard files are written).
    """
    # Each thread produces N shards where N = num_shards // num_threads.
    # For instance, if num_shards = 128, and the num_threads = 2, then the
    # first thread would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = num_shards // num_threads

    first, last = ranges[thread_index][0], ranges[thread_index][1]
    shard_ranges = np.linspace(first, last,
                               num_shards_per_batch + 1).astype(int)
    num_files_in_thread = last - first

    counter = 0
    for s in range(num_shards_per_batch):
        # Sharded output name, e.g. 'satellite_train_00000-of-00002.tfrecord'.
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s_%s_%.5d-of-%.5d.tfrecord' % (
            command_args.dataset_name, name, shard, num_shards)
        output_file = os.path.join(command_args.output_directory,
                                   output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1],
                                   dtype=int)
        try:
            for i in files_in_shard:
                filename = filenames[i]
                label = labels[i]
                text = texts[i]

                image_buffer, height, width = _process_image(filename, coder)
                example = _convert_to_example(filename, image_buffer, label,
                                              text, height, width)
                writer.write(example.SerializeToString())
                shard_counter += 1
                counter += 1

                if not counter % 1000:
                    logging.info(
                        '%s [thread %d]: Processed %d of %d images in thread batch.' %
                        (datetime.now(), thread_index, counter,
                         num_files_in_thread))
                    sys.stdout.flush()
        finally:
            # Close the shard even when an image fails to process, so the
            # file handle is not leaked by a dying thread.
            writer.close()

        logging.info('%s [thread %d]: Wrote %d images to %s' %
                     (datetime.now(), thread_index, shard_counter, output_file))
        sys.stdout.flush()

    # Report the number of shards this thread wrote (the original passed the
    # number of files here, which made the message wrong).
    logging.info('%s [thread %d]: Wrote %d images to %d shards.' %
                 (datetime.now(), thread_index, counter, num_shards_per_batch))
    sys.stdout.flush()
def _process_image_files(name, filenames, texts, labels, num_shards, command_args):
    """Process and save list of images as TFRecord of Example protos.

    Args:
      name: string, unique identifier specifying the data set.
      filenames: list of strings; each string is a path to an image file.
      texts: list of strings; each string is human readable, e.g. 'dog'.
      labels: list of integer; each integer identifies the ground truth.
      num_shards: integer number of shards for this data set.
      command_args: object providing ``num_threads``; it is also forwarded
        unchanged to every worker thread.
    """
    assert len(filenames) == len(texts)
    assert len(filenames) == len(labels)

    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    # The builtin int is used as dtype: the np.int alias was removed from
    # NumPy and raises AttributeError on current releases.
    spacing = np.linspace(0, len(filenames),
                          command_args.num_threads + 1).astype(int)
    ranges = [[spacing[i], spacing[i + 1]] for i in range(len(spacing) - 1)]

    # Launch a thread for each batch.
    logging.info('Launching %d threads for spacings: %s' %
                 (command_args.num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # One coder (and therefore one TensorFlow session) is shared by all threads.
    coder = ImageCoder()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames,
                texts, labels, num_shards, command_args)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    logging.info('%s: Finished writing all %d images in data set.' %
                 (datetime.now(), len(filenames)))
    sys.stdout.flush()
def _find_image_files(data_dir, labels_file, command_args):
    """Build a list of all images files and labels in the data set.

    Args:
      data_dir: string, path to the root directory of images.

        Assumes that the image data set resides in JPEG files located in
        the following directory structure.

          data_dir/dog/another-image.JPEG
          data_dir/dog/my-image.jpg

        where 'dog' is the label associated with these images.

      labels_file: string, path to the labels file.

        The list of valid labels are held in this file. Assumes that the file
        contains entries as such:
          dog
          cat
          flower
        where each line corresponds to a label. Labels are mapped to
        consecutive integers in file order, the first line receiving
        ``command_args.class_label_base``.

      command_args: object providing ``class_label_base``, the integer id
        given to the first label.

    Returns:
      filenames: list of strings; each string is a path to an image file.
      texts: list of strings; each string is the class, e.g. 'dog'
      labels: list of integer; each integer identifies the ground truth.
    """
    logging.info('Determining list of input files and labels from %s.' % data_dir)
    # Read inside a context manager so the file handle is always closed.
    with tf.gfile.FastGFile(labels_file, 'r') as f:
        unique_labels = [l.strip() for l in f.readlines()]

    labels = []
    filenames = []
    texts = []

    # Important: label ids start at command_args.class_label_base rather than
    # reserving index 0 for a background class.
    label_index = command_args.class_label_base

    # Construct the list of JPEG files and labels.
    for text in unique_labels:
        jpeg_file_path = '%s/%s/*' % (data_dir, text)
        # Sorted so the seeded shuffle below starts from the same order on
        # every run, whatever order the glob returns.
        matching_files = sorted(tf.gfile.Glob(jpeg_file_path))

        labels.extend([label_index] * len(matching_files))
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)

        if not label_index % 100:
            # The total is the number of classes, not the number of files
            # collected so far.
            logging.info('Finished finding files in %d of %d classes.' %
                         (label_index, len(unique_labels)))
        label_index += 1

    # Shuffle the ordering of all image files in order to guarantee
    # random ordering of the images with respect to label in the
    # saved TFRecord files. Make the randomization repeatable.
    shuffled_index = list(range(len(filenames)))
    random.seed(12345)
    random.shuffle(shuffled_index)

    filenames = [filenames[i] for i in shuffled_index]
    texts = [texts[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]

    logging.info('Found %d JPEG files across %d labels inside %s.' %
                 (len(filenames), len(unique_labels), data_dir))
    return filenames, texts, labels
def _process_dataset(name, directory, num_shards, labels_file, command_args):
    """Convert one complete split of the data set into TFRecord shards.

    Args:
      name: string, unique identifier specifying the data set.
      directory: string, root path to the data set.
      num_shards: integer number of shards for this data set.
      labels_file: string, path to the labels file.
      command_args: options object forwarded to the two helpers.
    """
    # First collect the (file, class text, class id) triples, then write them.
    found = _find_image_files(directory, labels_file, command_args)
    image_paths, class_texts, class_ids = found
    _process_image_files(name, image_paths, class_texts, class_ids,
                         num_shards, command_args)
def check_and_set_default_args(command_args):
    """Fill in missing optional settings and validate the required ones.

    An optional attribute that is absent or None receives its default.  The
    function then asserts that both shard counts are multiples of the thread
    count and that the four path attributes are not None.

    Args:
      command_args: mutable options object; updated in place.
    """
    fallbacks = (
        ('train_shards', 5),
        ('validation_shards', 5),
        ('num_threads', 5),
        ('class_label_base', 0),
        ('dataset_name', ''),
    )
    for attr_name, fallback in fallbacks:
        if getattr(command_args, attr_name, None) is None:
            setattr(command_args, attr_name, fallback)

    # Each thread writes the same number of shards, hence the divisibility.
    assert command_args.train_shards % command_args.num_threads == 0, (
        'Please make the command_args.num_threads commensurate with command_args.train_shards')
    assert command_args.validation_shards % command_args.num_threads == 0, (
        'Please make the command_args.num_threads commensurate with '
        'command_args.validation_shards')

    assert command_args.train_directory is not None
    assert command_args.validation_directory is not None
    assert command_args.labels_file is not None
    assert command_args.output_directory is not None
def main(command_args):
    """Convert the validation and training image folders to TFRecord files.

    ``command_args`` needs the following attributes:
      train_directory: folder of the training set. Each sub-folder is named
        after a label and contains that label's images.
      validation_directory: folder of the validation set, same layout.
      labels_file: a file in which every line is one label name.
      output_directory: folder that receives the output files.
      train_shards: number of shards the training set is split into.
      validation_shards: number of shards the validation set is split into.
      num_threads: number of threads; must divide both shard counts.
      class_label_base: important! The id of the first class in the written
        records; defaults to 0 (models/slim also starts at 0).
      dataset_name: string used as prefix of the output file names.

    Images must not be corrupted, otherwise the worker thread exits early.
    """
    check_and_set_default_args(command_args)
    logging.info('Saving results to %s' % command_args.output_directory)

    # The validation split is converted first, then the training split.
    splits = (
        ('validation', command_args.validation_directory,
         command_args.validation_shards),
        ('train', command_args.train_directory, command_args.train_shards),
    )
    for split_name, split_dir, shard_count in splits:
        _process_dataset(split_name, split_dir, shard_count,
                         command_args.labels_file, command_args)
结果会生成:
五个文件。
代码是书中的代码,但是需要改动:
第一:
def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    value = tf.compat.as_bytes(value)  # 这行需要添加
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
第二:
def _process_image(filename, coder):
    with open(filename, 'rb') as f:  # 这里需要加个b
        image_data = f.read()
第三:
xrange需要都改为range
第四:
_find_image_files:
    shuffled_index = list(range(len(filenames)))  # 这里加上了list
否则会报错。
制作完数据后要进行训练了。
使用tensorflow slim 微调模型。
TensorFlow Slim是Google公司公布的一个图像分类工具包,它不仅定义了一些方便的接口,还提供了很多ImageNet数据集上常用的网络结构和预训练模型。截至2017年7月,Slim提供包括VGG16、VGG19、Inception V1~V4、ResNet 50、ResNet 101、MobileNet在内的大多数常用模型的结构以及预训练模型,更多的模型还会被持续添加进来。
先介绍如何下载Slim的源代码,再介绍如何在Slim中定义新的数据库,最后介绍如何使用新的数据库训练以及如何进行参数调整。
如果需要使用Slim微调模型,首先要下载Slim的源代码。
Slim的源代码保存在tensorflow/models项目中,可以使用下面的git 命令下载tensorflow/models:
git clone https://github.com/tensorflow/models.git
找到models/research/目录中的slim文件夹,这就是要用到的TensorFlow Slim的源代码。
这里简单介绍TensorFlow Slim的代码结构:
表只列出了TensorFlow Slim中最重要的几个文件以及文件夹的作用。
其他还有少量文件和文件夹,如果你对它们的作用感兴趣,可以自行参阅其文档。
定义新的datasets文件
在slim/datasets中,定义了所有可以使用的数据库,为了使用自己创建的tfrecord数据进行训练,必须要在datasets中定义新的数据库。
首先,在datasets/目录下新建一个文件satellite.py,并将flowers.py文件中的内容复制到satellite.py中。
flowers.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.
The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from datasets import dataset_utils
slim = tf.contrib.slim
_FILE_PATTERN = 'flowers_%s_*.tfrecord'
SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
_NUM_CLASSES = 5
_ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and 4',
}
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Build the slim Dataset that describes one split of the TFRecord files.

    Args:
      split_name: A train/validation split name.
      dataset_dir: The base directory of the dataset sources.
      file_pattern: The file pattern to use when matching the dataset sources.
        It is assumed that the pattern contains a '%s' string so that the
        split name can be inserted. Falls back to _FILE_PATTERN when empty.
      reader: The TensorFlow reader type; tf.TFRecordReader when None.

    Returns:
      A `Dataset` namedtuple.

    Raises:
      ValueError: if `split_name` is not a valid train/validation split.
    """
    if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

    pattern = file_pattern if file_pattern else _FILE_PATTERN
    data_sources = os.path.join(dataset_dir, pattern % split_name)

    # None is allowed in the signature so dataset_factory can use the default.
    record_reader = tf.TFRecordReader if reader is None else reader

    # Keys are the feature names stored in each serialized Example.
    label_default = tf.zeros([], dtype=tf.int64)
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=label_default),
    }
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    # The label names are only loaded when dataset_dir holds a labels file.
    if dataset_utils.has_labels(dataset_dir):
        labels_to_names = dataset_utils.read_label_file(dataset_dir)
    else:
        labels_to_names = None

    return slim.dataset.Dataset(
        data_sources=data_sources,
        reader=record_reader,
        decoder=decoder,
        num_samples=SPLITS_TO_SIZES[split_name],
        items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
        num_classes=_NUM_CLASSES,
        labels_to_names=labels_to_names)
接下来,需要修改以下几处内容:
第一处是_FILE_PATTERN、SPLITS_TO_SIZES、_NUM_CLASSES,将其进行以下修改:
_FILE_PATTERN变量定义了数据的文件名的格式,SPLITS_TO_SIZES变量定义了训练集、验证集的数量。
这里定义_FILE_PATTERN = 'satellite_%s_*.tfrecord'和SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200},
就表明数据集中,训练集的文件名格式为satellite_train_*.tfrecord,共包含4800张图片,验证集文件名格式为satellite_validation_*.tfrecord,共包含1200张图片。
_NUM_CLASSES变量定义了数据集中图片的类别数目,本数据集共有6个类别,因此应设置为_NUM_CLASSES = 6。
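第二处修改为image/format部分,将默认值由'png'改为'jpg'(转换脚本写入的图片均为JPEG编码),即修改为:
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),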
此处定义了图片的默认格式。
最后,读者也可以对文件中的注释内容进行合适的修改。
修改完satellite.py后,还需要在同目录的dataset_factory.py文件中注册satellite数据库。
dataset_factory.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A factory-pattern class which returns classification image/label pairs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite
datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'satellite': satellite,
}
def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
    """Look up a registered dataset module and return one of its splits.

    Args:
      name: String, the name of the dataset.
      split_name: A train/test split name.
      dataset_dir: The directory where the dataset files are stored.
      file_pattern: The file pattern to use for matching the dataset source
        files.
      reader: The subclass of tf.ReaderBase. If left as `None`, then the
        default reader defined by each dataset is used.

    Returns:
      A `Dataset` class.

    Raises:
      ValueError: If the dataset `name` is unknown.
    """
    if name not in datasets_map:
        raise ValueError('Name of dataset unknown %s' % name)
    # Every registered module exposes the same get_split() entry point.
    split_factory = datasets_map[name].get_split
    return split_factory(split_name, dataset_dir, file_pattern, reader)
未修改的dataset_factory.py中注册数据库的对应代码为:
很显然,此时只注册了4个数据库,对这部分进行修改,将satellite模块也添加进来就可以了:
准备训练文件夹
定义完数据集后,在slim文件夹下再新建一个satellite目录,在这个目录中,完成最后的几项准备工作:
(1)新建一个data目录,并将准备好的5个转换好格式的训练数据复制进去。
(2)新建一个空的train_dir目录,用来保存训练过程中的日志和模型。
(3)新建一个pretrained目录,找到Inception V3模型.
下载地址https://download.csdn.net/download/m0_37407756/10494617,
下载并解压后,会得到一个inception_v3.ckpt文件,将该文件复制到pretrained目录下。
最后形成的目录结构为:
开始训练
在slim文件夹下,运行以下命令就可以开始训练了:
python3 train_image_classifier.py \
--train_dir=satellite/train_dir \
--dataset_name=satellite \
--dataset_split_name=train \
--dataset_dir=satellite/data \
--model_name=inception_v3 \
--checkpoint_path=satellite/pretrained/inception_v3.ckpt \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--max_number_of_steps=100000 \
--batch_size=32 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed \
--save_summaries_secs=2 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004
还可以使用以下命令对所有层进行训练:
对比只训练末端层的命令,只有一处发生了变化,即去掉了--trainable_scopes参数。
原先的--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits表示只对末端层InceptionV3/Logits和
InceptionV3/AuxLogits进行训练,去掉后就可以训练模型中的所有参数了。
train_image_classifier.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic training script that trains a model using a given dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from datasets import dataset_factory
from deployment import model_deploy
from nets import nets_factory
from preprocessing import preprocessing_factory
slim = tf.contrib.slim
tf.app.flags.DEFINE_string(
'master', '', 'The address of the TensorFlow master to use.')
tf.app.flags.DEFINE_string(
'train_dir', '/tmp/tfmodel/',
'Directory where checkpoints and event logs are written to.')
tf.app.flags.DEFINE_integer('num_clones', 1,
'Number of model clones to deploy.')
tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
'Use CPUs to deploy clones.')
tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')
tf.app.flags.DEFINE_integer(
'num_ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
tf.app.flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
tf.app.flags.DEFINE_integer(
'num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
tf.app.flags.DEFINE_integer(
'log_every_n_steps', 10,
'The frequency with which logs are print.')
tf.app.flags.DEFINE_integer(
'save_summaries_secs', 600,
'The frequency with which summaries are saved, in seconds.')
tf.app.flags.DEFINE_integer(
'save_interval_secs', 600,
'The frequency with which the model is saved, in seconds.')
tf.app.flags.DEFINE_integer(
'task', 0, 'Task id of the replica running the training.')
######################
# Optimization Flags #
######################
tf.app.flags.DEFINE_float(
'weight_decay', 0.00004, 'The weight decay on the model weights.')
# Name of the optimizer to build; the help text lists the accepted values.
tf.app.flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
    '"ftrl", "momentum", "sgd" or "rmsprop".')
tf.app.flags.DEFINE_float(
'adadelta_rho', 0.95,
'The decay rate for adadelta.')
tf.app.flags.DEFINE_float(
'adagrad_initial_accumulator_value', 0.1,
'Starting value for the AdaGrad accumulators.')
tf.app.flags.DEFINE_float(
'adam_beta1', 0.9,
'The exponential decay rate for the 1st moment estimates.')
tf.app.flags.DEFINE_float(
'adam_beta2', 0.999,
'The exponential decay rate for the 2nd moment estimates.')
tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
'The learning rate power.')
tf.app.flags.DEFINE_float(
'ftrl_initial_accumulator_value', 0.1,
'Starting value for the FTRL accumulators.')
tf.app.flags.DEFINE_float(
'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
tf.app.flags.DEFINE_float(
'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
tf.app.flags.DEFINE_float(
'momentum', 0.9,
'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
#######################
# Learning Rate Flags #
#######################
tf.app.flags.DEFINE_string(
'learning_rate_decay_type',
'exponential',
'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
' or "polynomial"')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
tf.app.flags.DEFINE_float(
'end_learning_rate', 0.0001,
'The minimal end learning rate used by a polynomial decay learning rate.')
tf.app.flags.DEFINE_float(
'label_smoothing', 0.0, 'The amount of label smoothing.')
tf.app.flags.DEFINE_float(
'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
tf.app.flags.DEFINE_float(
'num_epochs_per_decay', 2.0,
'Number of epochs after which learning rate decays.')
tf.app.flags.DEFINE_bool(
'sync_replicas', False,
'Whether or not to synchronize the replicas during training.')
tf.app.flags.DEFINE_integer(
'replicas_to_aggregate', 1,
'The Number of gradients to collect before updating params.')
# Optional decay for the moving average; defaults to None.
tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')
#######################
# Dataset Flags #
#######################
tf.app.flags.DEFINE_string(
'dataset_name', 'imagenet', 'The name of the dataset to load.')
tf.app.flags.DEFINE_string(
'dataset_split_name', 'train', 'The name of the train/test split.')
tf.app.flags.DEFINE_string(
'dataset_dir', None, 'The directory where the dataset files are stored.')
tf.app.flags.DEFINE_integer(
'labels_offset', 0,
'An offset for the labels in the dataset. This flag is primarily used to '
'evaluate the VGG and ResNet architectures which do not use a background '
'class for the ImageNet dataset.')
tf.app.flags.DEFINE_string(
'model_name', 'inception_v3', 'The name of the architecture to train.')
tf.app.flags.DEFINE_string(
'preprocessing_name', None, 'The name of the preprocessing to use. If left '
'as `None`, then the model_name flag is used.')
tf.app.flags.DEFINE_integer(
'batch_size', 32, 'The number of samples in each batch.')
tf.app.flags.DEFINE_integer(
'train_image_size', None, 'Train image size')
tf.app.flags.DEFINE_integer('max_number_of_steps', None,
'The maximum number of training steps.')
#####################
# Fine-Tuning Flags #
#####################
tf.app.flags.DEFINE_string(
'checkpoint_path', None,
'The path to a checkpoint from which to fine-tune.')
tf.app.flags.DEFINE_string(
'checkpoint_exclude_scopes', None,
'Comma-separated list of scopes of variables to exclude when restoring '
'from a checkpoint.')
# Optional comma-separated scopes that restrict which variables are trained.
tf.app.flags.DEFINE_string(
    'trainable_scopes', None,
    'Comma-separated list of scopes to filter the set of variables to train. '
    'By default, None would train all the variables.')
tf.app.flags.DEFINE_boolean(
'ignore_missing_vars', False,
'When restoring a checkpoint would ignore missing variables.')
FLAGS = tf.app.flags.FLAGS
def _configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Args:
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.

    Returns:
      A `Tensor` representing the learning rate.

    Raises:
      ValueError: if FLAGS.learning_rate_decay_type is not one of
        "exponential", "fixed" or "polynomial".
    """
    # Number of steps between decays: batches per epoch times epochs per decay.
    decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                      FLAGS.num_epochs_per_decay)
    if FLAGS.sync_replicas:
        decay_steps /= FLAGS.replicas_to_aggregate

    if FLAGS.learning_rate_decay_type == 'exponential':
        return tf.train.exponential_decay(FLAGS.learning_rate,
                                          global_step,
                                          decay_steps,
                                          FLAGS.learning_rate_decay_factor,
                                          staircase=True,
                                          name='exponential_decay_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'fixed':
        return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'polynomial':
        return tf.train.polynomial_decay(FLAGS.learning_rate,
                                         global_step,
                                         decay_steps,
                                         FLAGS.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name='polynomial_decay_learning_rate')
    else:
        # Format the message with %; passing the value as a second argument
        # would leave the '%s' placeholder unformatted in the error.
        raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                         FLAGS.learning_rate_decay_type)
def _configure_optimizer(learning_rate):
    """Configures the optimizer used for training.

    The optimizer class is selected by FLAGS.optimizer and its
    hyper-parameters are read from the matching optimization flags.

    Args:
      learning_rate: A scalar or `Tensor` learning rate.

    Returns:
      An instance of an optimizer.

    Raises:
      ValueError: if FLAGS.optimizer is not recognized.
    """
    if FLAGS.optimizer == 'adadelta':
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate,
            rho=FLAGS.adadelta_rho,
            epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'adagrad':
        optimizer = tf.train.AdagradOptimizer(
            learning_rate,
            initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
    elif FLAGS.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate,
            beta1=FLAGS.adam_beta1,
            beta2=FLAGS.adam_beta2,
            epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'ftrl':
        optimizer = tf.train.FtrlOptimizer(
            learning_rate,
            learning_rate_power=FLAGS.ftrl_learning_rate_power,
            initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
            l1_regularization_strength=FLAGS.ftrl_l1,
            l2_regularization_strength=FLAGS.ftrl_l2)
    elif FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(
            learning_rate,
            momentum=FLAGS.momentum,
            name='Momentum')
    elif FLAGS.optimizer == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=FLAGS.rmsprop_decay,
            momentum=FLAGS.momentum,
            epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        # Format the message with %; passing the value as a second argument
        # would leave the '%s' placeholder unformatted in the error.
        raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
    return optimizer
def _get_init_fn():
    """Build the warm-start function that loads the fine-tuning checkpoint.

    Note that the init_fn is only run when initializing the model during the
    very first global step.

    Returns:
      An init function run by the supervisor, or None when
      FLAGS.checkpoint_path is unset or FLAGS.train_dir already holds a
      checkpoint.
    """
    if FLAGS.checkpoint_path is None:
        return None

    # A checkpoint in train_dir takes precedence over --checkpoint_path, so
    # tell the user the flag is being ignored.
    if tf.train.latest_checkpoint(FLAGS.train_dir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists in %s'
            % FLAGS.train_dir)
        return None

    excluded_prefixes = ()
    if FLAGS.checkpoint_exclude_scopes:
        excluded_prefixes = tuple(
            scope.strip()
            for scope in FLAGS.checkpoint_exclude_scopes.split(','))

    # Keep every model variable whose name starts with none of the prefixes.
    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(excluded_prefixes)]

    # A directory is resolved to the latest checkpoint it contains.
    checkpoint_path = FLAGS.checkpoint_path
    if tf.gfile.IsDirectory(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    tf.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=FLAGS.ignore_missing_vars)
def _get_variables_to_train():
    """Returns a list of variables to train.

    When --trainable_scopes is unset every trainable variable is returned;
    otherwise only the trainable variables collected for each listed scope.

    Returns:
        A list of variables to train by the optimizer.
    """
    if FLAGS.trainable_scopes is None:
        return tf.trainable_variables()

    variables_to_train = []
    for scope in FLAGS.trainable_scopes.split(','):
        scope = scope.strip()
        if not scope:
            # An empty scope (e.g. from a trailing comma) would select every
            # trainable variable, defeating the purpose of the flag.
            continue
        variables_to_train.extend(
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope))
    return variables_to_train
def main(_):
    """Builds the training graph and runs the slim training loop.

    Reads all configuration from FLAGS: selects the dataset, network and
    preprocessing function, builds the input queue, creates the model clones
    and their losses, configures the optimizer (optionally synchronized and
    with moving averages) and finally calls slim.learning.train.

    Raises:
        ValueError: if --dataset_dir is not supplied.
    """
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = (FLAGS.train_image_size or
                                network_fn.default_image_size)

            image = image_preprocessing_fn(
                image, train_image_size, train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            with tf.device(deploy_config.inputs_device()):
                images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(
                    logits=end_points['AuxLogits'], onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing, weights=0.4,
                    scope='aux_loss')
            tf.losses.softmax_cross_entropy(
                logits=logits, onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing, weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(
            deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(
                dataset.num_samples, global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the
            # chief queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
# Script entry point: tf.app.run() parses the flags and then calls main().
if __name__ == '__main__':
    tf.app.run()
训练程序行为
当train_image_classifier.py程序启动后,如果训练文件夹(即satellite/train_dir)里没有已经保存的模型,就会加载checkpoint_path中的预训练模型。紧接着,程序会把初始模型保存到train_dir中,命名为model.ckpt-0,0表示第0步。这之后,每隔5min(参数--save_interval_secs=300指定了每隔300s保存一次,即5min),程序还会把当前模型保存到同样的文件夹中,命名格式和第一次保存的格式一样。因为模型比较大,程序只会保留最新的5个模型。此外,如果中断了程序并再次运行,程序会首先检查train_dir中有无已经保存的模型,如果有,就不会去加载checkpoint_path中的预训练模型,而是直接加载train_dir中已经训练好的模型,并以此为起点进行训练。Slim之所以这样设计,是为了在微调网络的时候,可以方便地按阶段手动调整学习率等参数。
TensorBoard可视化与超参数选择
使用下列命令可以打开TensorBoard(其实就是指定训练文件夹):
tensorboard --logdir satellite/train_dir
在TensorBoard中,可以看到损失的变化曲线。观察损失曲线有助于调整参数。
当损失曲线比较平缓,收敛较慢时,可以考虑增大学习率,以加快收敛速度;
如果损失曲线波动较大,无法收敛,就可能是因为学习率过大,此时就可以尝试适当减小学习率。
验证模型准确率
用eval_image_classifier.py程序进行验证,即执行下列命令:
python3 eval_image_classifier.py \
--checkpoint_path=satellite/train_dir \
--eval_dir=satellite/eval_dir \
--dataset_name=satellite \
--dataset_split_name=validation \
--dataset_dir=satellite/data \
--model_name=inception_v3
eval_image_classifier.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic evaluation script that evaluates a model using a given dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
from datasets import dataset_factory
from nets import nets_factory
from preprocessing import preprocessing_factory
slim = tf.contrib.slim

# Command-line flags of the evaluation script.
tf.app.flags.DEFINE_integer(
    'batch_size', 100, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'max_num_batches', None,
    'Max number of batches to evaluate by default use all.')

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'test', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

# The two help fragments are concatenated, so the first one needs a trailing
# space to keep the sentences apart in --help output.
tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

tf.app.flags.DEFINE_integer(
    'eval_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS
def main(_):
    """Evaluates one checkpoint of the selected model on the selected split.

    Builds the evaluation input pipeline and network from FLAGS, defines the
    streaming 'Accuracy' and 'Recall_5' metrics, and runs
    slim.evaluation.evaluate_once over the computed number of batches.

    Raises:
        ValueError: if --dataset_dir is not supplied, or if no checkpoint can
            be resolved from --checkpoint_path.
    """
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
            'Recall_5': slim.metrics.streaming_recall_at_k(
                logits, labels, 5),
        })

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            # int() keeps the count integral on Python 2, where math.ceil
            # returns a float.
            num_batches = int(
                math.ceil(dataset.num_samples / float(FLAGS.batch_size)))

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        # latest_checkpoint returns None when the directory holds no
        # checkpoint; fail here with a clear message.
        if not checkpoint_path:
            raise ValueError(
                'No checkpoint found at %s' % FLAGS.checkpoint_path)

        tf.logging.info('Evaluating %s' % checkpoint_path)

        slim.evaluation.evaluate_once(
            master=FLAGS.master,
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            variables_to_restore=variables_to_restore)

        # Alternative kept by the author: evaluate repeatedly as new
        # checkpoints appear instead of evaluating once.
        # slim.evaluation.evaluation_loop(
        #     master=FLAGS.master,
        #     checkpoint_dir=FLAGS.checkpoint_path,
        #     logdir=FLAGS.eval_dir,
        #     num_evals=num_batches,
        #     eval_op=list(names_to_updates.values()),
        #     variables_to_restore=variables_to_restore,
        #     eval_interval_secs=300
        # )
# Script entry point: tf.app.run() parses the flags and then calls main().
if __name__ == '__main__':
    tf.app.run()
执行后,应该会出现类似下面的结果:
Accuracy表示模型的分类准确率,而Recall_5表示Top 5的准确率,即
在输出的各类别概率中,正确的类别只要落在前5个就算对。
由于此处的类别数比较少,因此可以不执行Top 5的准确率,换而执行Top 2或者Top 3的准确率,
只要在eval_image_classifier.py中修改下面的部分就可以了:即把slim.metrics.streaming_recall_at_k(logits, labels, 5)中的5改为2或3(指标名'Recall_5'也可以相应改名)。
导出模型并对单张图片进行识别
训练完模型后,常见的应用场景是:部署训练好的模型并对单张图片做识别。
这里提供了两个代码文件:freeze_graph.py和classify_image_inception_v3.py。
前者可以导出一个用于识别的模型,后者则是使用inception-v3模型对单张图片做识别的脚本。
TensorFlow Slim为我们提供了导出网络结构的脚本export_inference_graph.py。首先在slim文件夹下运行:
python3 export_inference_graph.py \
> --alsologtostderr \
> --model_name=inception_v3 \
> --output_file=satellite/inception_v3_inf_graph.pb \
> --dataset_name satellite
export_inference_graph.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a GraphDef containing the architecture of the model.
To use it, run something like this, with a model name defined by slim:
bazel build tensorflow_models/slim:export_inference_graph
bazel-bin/tensorflow_models/slim/export_inference_graph \
--model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb
If you then want to use the resulting model with your own or pretrained
checkpoints as part of a mobile model, you can run freeze_graph to get a graph
def with the variables inlined as constants using:
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/inception_v3_inf_graph.pb \
--input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
--input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1
The output node names will vary depending on the model, but you can inspect and
estimate them using the summarize_graph tool:
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/tmp/inception_v3_inf_graph.pb
To run the resulting graph in C++, you can look at the label_image sample code:
bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
--image=${HOME}/Pictures/flowers.jpg \
--input_layer=input \
--output_layer=InceptionV3/Predictions/Reshape_1 \
--graph=/tmp/frozen_inception_v3.pb \
--labels=/tmp/imagenet_slim_labels.txt \
--input_mean=0 \
--input_std=255 \
--logtostderr
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory
slim = tf.contrib.slim

# --- Command-line flags of the export script ---------------------------------

# Which architecture to export.
tf.app.flags.DEFINE_string(
    'model_name',
    'inception_v3',
    'The name of the architecture to save.')

# Export the training or the inference variant of the graph.
tf.app.flags.DEFINE_boolean(
    'is_training',
    False,
    'Whether to save out a training-focused version of the model.')

# Fallback input size, used by main() when the network function does not
# define default_image_size.
tf.app.flags.DEFINE_integer(
    'default_image_size',
    224,
    'The image size to use if the model does not define it.')

# Dataset whose class count sizes the output layer.
tf.app.flags.DEFINE_string(
    'dataset_name',
    'imagenet',
    'The name of the dataset to use with the model.')

tf.app.flags.DEFINE_integer(
    'labels_offset',
    0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

# Destination of the serialized GraphDef; main() rejects an empty value.
tf.app.flags.DEFINE_string(
    'output_file',
    '',
    'Where to save the resulting file to.')

tf.app.flags.DEFINE_string(
    'dataset_dir',
    '',
    'Directory to save intermediate dataset files to')

FLAGS = tf.app.flags.FLAGS
def main(_):
    """Writes the GraphDef of the selected network to --output_file.

    Builds the network on a float32 placeholder named 'input' with shape
    [1, image_size, image_size, 3] and serializes the resulting graph
    definition in binary form.

    Raises:
        ValueError: if --output_file is not supplied.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'validation',
                                              FLAGS.dataset_dir)
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=FLAGS.is_training)

        # Use the model's own input size when it defines one, otherwise the
        # --default_image_size flag.
        image_size = getattr(
            network_fn, 'default_image_size', FLAGS.default_image_size)

        placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                     shape=[1, image_size, image_size, 3])
        network_fn(placeholder)
        graph_def = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as f:
            f.write(graph_def.SerializeToString())
# Script entry point: tf.app.run() parses the flags and then calls main().
if __name__ == '__main__':
    tf.app.run()
这个命令会在satellite文件夹中生成一个inception_v3_inf_graph.pb文件。
注意:inception_v3_inf_graph.pb文件中只保存了Inception V3的网络结构,
并不包含训练得到的模型参数,需要将checkpoint中的模型参数保存进来。
方法是使用freeze_graph.py脚本(在chapter-3文件夹下运行):
python3 freeze_graph.py \
> --input_graph slim/satellite/inception_v3_inf_graph.pb \
> --input_checkpoint slim/satellite/train_dir/model.ckpt-100000 \
> --input_binary true \
> --output_node_names InceptionV3/Predictions/Reshape_1 \
> --output_graph slim/satellite/frozen_graph.pb
解释:
> --input_graph slim/satellite/inception_v3_inf_graph.pb :
它表示使用的网络结构文件inception_v3_inf_graph.pb 即之前已经导出的。
> --input_checkpoint slim/satellite/train_dir/model.ckpt-100000 :
具体将哪一个checkpoint的参数载入到网络结构中。
> --input_binary true:
导入的inception_v3_inf_graph.pb实际是一个protobuf 文件。
而protobuf文件有两种保存格式,一种是文本形式,一种是二进制形式。
inception_v3_inf_graph.pb是二进制形式,所以对应的参数是--input_binary true。
--output_node_names InceptionV3/Predictions/Reshape_1:
在导出的模型中,指定一个输出的结点, InceptionV3/Predictions/Reshape_1是 InceptionV3的最后的输出层。
> --output_graph slim/satellite/frozen_graph.pb
最后导出的模型保存在slim/satellite/frozen_graph.pb文件。
使用classify_image_inception_v3.py 完成预测:
python3 classify_image_inception_v3.py \
> --model_path slim/satellite/frozen_graph.pb \
> --label_path data_prepare/pic/label.txt \
> --image_file test_image.jpg
解释:
--model_path很好理解,就是之前导出的模型frozen_graph.pb。模型的输出实际是“第0类”、“第1类”……所以用
--label_path指定了一个label文件,label文件中按顺序存储了各个类别的名称,这样脚本就可以把类别的id号转换为实际的类别名。
--image_file是需要测试的单张图片。
classify_image_inception_v3.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import re
import sys
import tarfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
FLAGS = None
class NodeLookup(object):
    """Maps integer class ids to the label names listed in a label file.

    The label file holds one class name per line; the zero-based line index
    is the class id.
    """

    def __init__(self, label_lookup_path=None):
        """Loads the id -> name table from `label_lookup_path`."""
        self.node_lookup = self.load(label_lookup_path)

    def load(self, label_lookup_path):
        """Reads the label file and returns a dict of line index -> name.

        Args:
            label_lookup_path: path of the text file with one label per line.

        Returns:
            A dict mapping each zero-based line index to the stripped line.
        """
        with open(label_lookup_path) as f:
            return {index: line.strip() for index, line in enumerate(f)}

    def id_to_string(self, node_id):
        """Returns the label for `node_id`, or '' when the id is unknown."""
        return self.node_lookup.get(node_id, '')
def create_graph():
    """Imports the GraphDef stored at FLAGS.model_path into the default graph.

    The nodes are imported with an empty name prefix, so tensors keep the
    names they have in the file. Nothing is returned.
    """
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(FLAGS.model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
def preprocess_for_eval(image, height, width,
                        central_fraction=0.875, scope=None):
    """Prepares one image tensor for evaluation.

    Converts the image to float32 if needed, optionally takes a central crop,
    optionally resizes it bilinearly to (height, width), then subtracts 0.5
    and multiplies by 2.0.

    Args:
        image: image tensor to preprocess.
        height: target height; resizing is skipped unless both height and
            width are truthy.
        width: target width.
        central_fraction: fraction passed to tf.image.central_crop; a falsy
            value skips the crop.
        scope: optional name scope.

    Returns:
        The preprocessed image tensor.
    """
    with tf.name_scope(scope, 'eval_image', [image, height, width]):
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # Crop the central region of the image with an area containing 87.5%
        # of the original image.
        if central_fraction:
            image = tf.image.central_crop(
                image, central_fraction=central_fraction)

        if height and width:
            # Resize the image to the specified height and width.
            # resize_bilinear works on batches, hence the temporary batch dim.
            image = tf.expand_dims(image, 0)
            image = tf.image.resize_bilinear(image, [height, width],
                                             align_corners=False)
            image = tf.squeeze(image, [0])
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        return image
def run_inference_on_image(image):
    """Runs inference on an image and prints the top predictions.

    The image is decoded and preprocessed in a temporary graph, then fed to
    the 'input:0' tensor of the graph loaded by create_graph(). The scores
    printed are the values of 'InceptionV3/Logits/SpatialSqueeze:0'.

    Args:
        image: Image file name.

    Returns:
        Nothing
    """
    with tf.Graph().as_default():
        # Read through a context manager so the file handle is closed.
        with tf.gfile.FastGFile(image, 'rb') as f:
            image_data = f.read()
        # channels=3 keeps the decoded image compatible with the 3-channel
        # network input even for grayscale files.
        image_data = tf.image.decode_jpeg(image_data, channels=3)
        image_data = preprocess_for_eval(image_data, 299, 299)
        image_data = tf.expand_dims(image_data, 0)
        with tf.Session() as sess:
            image_data = sess.run(image_data)

    # Creates graph from saved GraphDef.
    create_graph()

    with tf.Session() as sess:
        logits_tensor = sess.graph.get_tensor_by_name(
            'InceptionV3/Logits/SpatialSqueeze:0')
        predictions = sess.run(logits_tensor,
                               {'input:0': image_data})
        predictions = np.squeeze(predictions)

        # Creates node ID --> English string lookup.
        node_lookup = NodeLookup(FLAGS.label_path)

        # Indices of the highest scores, best first.
        top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))
def main(_):
    """Classifies the image named by FLAGS.image_file."""
    image = FLAGS.image_file
    run_inference_on_image(image)
# Script entry point: parse the known arguments into the module-level FLAGS
# and hand the remaining ones to tf.app.run.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        help='Path to the frozen GraphDef (.pb) file.'
    )
    parser.add_argument(
        '--label_path',
        type=str,
        help='Path to the label file, one class name per line.'
    )
    parser.add_argument(
        '--image_file',
        type=str,
        default='',
        help='Absolute path to image file.'
    )
    parser.add_argument(
        '--num_top_predictions',
        type=int,
        default=5,
        help='Display this many predictions.'
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
脚本的输出:
这就表示模型预测图片对应的最可能的类别是water,接着是wetland,urban, wood等。
score是各个类别对应的Logit。
在WEB上测试:
这块主要参考了:
https://blog.csdn.net/wlzard/article/details/77689311
server.py生成web
# coding=utf-8
import os
import sys
import importlib
importlib.reload(sys)
from flask import Flask, request, redirect, url_for
import uuid
import tensorflow as tf
from classify_image_inception_v3_test import set_flags, run_inference_on_image
# Upload extensions accepted by the web form ('PNG' added for parity with
# the upper-case JPG/JPEG entries).
ALLOWED_EXTENSIONS = set(['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG'])

FLAGS = tf.app.flags.FLAGS

# For reference, the argparse namespace of the command-line classifier:
#   Namespace(image_file='test_image.jpg',
#             label_path='data_prepare/pic/label.txt',
#             model_path='slim/satellite/frozen_graph.pb',
#             num_top_predictions=5)
tf.app.flags.DEFINE_string('model_path', 'slim/satellite/frozen_graph.pb', """*****, """)
tf.app.flags.DEFINE_string('label_path', 'data_prepare/pic/label.txt', '')
tf.app.flags.DEFINE_string('upload_folder', './didi', '')
tf.app.flags.DEFINE_integer('num_top_predictions', 6,
                            """Display this many predictions.""")
# The default of an integer flag must be an int; the help text now describes
# the real default instead of port 80.
tf.app.flags.DEFINE_integer('port', 5001,
                            'Port the server listens on.')
# NOTE(review): debug defaults to True and the server binds to 0.0.0.0;
# the Flask debugger should not be exposed on a public interface.
tf.app.flags.DEFINE_boolean('debug', True, '')

UPLOAD_FOLDER = FLAGS.upload_folder

app = Flask(__name__)
# Serve uploaded images from the upload folder under the default /static
# route, using the public property instead of the private attribute.
app.static_folder = UPLOAD_FOLDER
def allowed_files(filename):
    """Returns True when `filename` has an allowed image extension.

    The extension is compared case-insensitively, so 'a.PNG' and 'a.Jpg'
    are accepted as well as the spellings listed in ALLOWED_EXTENSIONS.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1]
    return extension in ALLOWED_EXTENSIONS or \
        extension.lower() in ALLOWED_EXTENSIONS
def rename_filename(old_file_name):
    """Returns a fresh random file name that keeps the original extension.

    Args:
        old_file_name: the uploaded file's name; only its extension is used.

    Returns:
        '<uuid><ext>', where <ext> is the extension of `old_file_name`
        (including the dot, or '' when there is none).
    """
    ext = os.path.splitext(os.path.basename(old_file_name))[1]
    # uuid4 is random; uuid1 would embed the host's MAC address and the
    # upload time in a name that is exposed through the /static URL.
    return str(uuid.uuid4()) + ext
def inference(file_name):
    """Classifies an uploaded image and returns an HTML fragment.

    Args:
        file_name: path of the saved upload.

    Returns:
        HTML showing the image followed by one '<label> (score:<value>)' line
        per prediction, or '' when inference fails.
    """
    # NOTE(review): the HTML tags in this function were stripped from the
    # pasted source and have been reconstructed -- confirm against the
    # original page.
    try:
        predictions, top_k, top_names = run_inference_on_image(file_name)
    except Exception:
        # Keep the best-effort behavior (empty result), but record the
        # failure instead of discarding it silently.
        app.logger.exception('inference failed for %s', file_name)
        return ""
    new_url = '/static/%s' % os.path.basename(file_name)
    image_tag = '<img src="%s"></img><p>'
    new_tag = image_tag % new_url
    format_string = ''
    for node_id, human_name in zip(top_k, top_names):
        score = predictions[node_id]
        format_string += '%s (score:%.5f)<BR>' % (human_name, score)
    ret_string = new_tag + format_string + '<BR>'
    return ret_string
@app.route("/", methods=['GET', 'POST'])
def root():
    """Serves the upload form and, on POST, the classification result.

    A POST with an allowed image file saves it under UPLOAD_FOLDER with a
    fresh name and appends the output of inference() to the page; any other
    request just returns the form.
    """
    # NOTE(review): the form markup was stripped from the pasted source and
    # has been reconstructed around the surviving text -- confirm against the
    # original page.
    result = """
    <!doctype html>
    <title>临时测试用</title>
    <h1>来一张照片吧</h1>
    <form action="" method=post enctype=multipart/form-data>
      <p><input type=file name=file value='选择图片'>
         <input type=submit value='上传'>
    </form>
    <p>%s</p>
    """ % "<br>"
    if request.method == 'POST':
        upload = request.files['file']
        old_file_name = upload.filename
        if upload and allowed_files(old_file_name):
            filename = rename_filename(old_file_name)
            file_path = os.path.join(UPLOAD_FOLDER, filename)
            upload.save(file_path)
            print('file saved to %s' % file_path)
            out_html = inference(file_path)
            return result + out_html
    return result
# Script entry point: share the parsed flags with the classifier module and
# start the Flask server on all interfaces.
if __name__ == "__main__":
    set_flags(FLAGS)
    print('listening on port %d' % FLAGS.port)
    app.run(host='0.0.0.0', port=FLAGS.port, debug=FLAGS.debug, threaded=True)
结果:
结束啦!