目录
TF-slim的测试
数据集制作
数据集验证
数据集注册(引入)
注册数据集
修改slim/train_image_classifier.py文件,进行训练。
运行以下命令,测试tf.contrib.slim模块是否已正确安装。从TF_MODELS/research/slim目录下运行:
python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"
创建目录slim/datasets/cvd,并在其下新建raw_data文件夹,用于存放下载好的猫狗大战训练图片(文件名须形如cat.0.jpg、dog.0.jpg,因为脚本根据文件名中第一个"."之前的前缀确定类别)。生成的tfrecord文件直接存放在cvd目录下。
新建数据集制作文件download_and_convert_cvd_v1_0.py,该文件直接从download_and_convert_flowers.py文件修改而来。代码如下:
r"""Converts the cat_vs_dog images to TFRecords of TF-Example protos.

Unlike the flowers script this module was derived from, nothing is
downloaded: the images must already be present in `<dataset_dir>/raw_data`,
with file names starting with `cat.` or `dog.` (the class is taken from that
prefix). The files are shuffled with a fixed seed and split into two TFRecord
datasets, one for training and one for validation (30% of the images), each
written as `_NUM_SHARDS` shard files. Each TFRecord dataset is comprised of a
set of TF-Example protocol buffers, each of which contains a single image and
its label.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import random
import sys
import tensorflow as tf
from datasets import dataset_utils
# URL of the original image archive. Nothing is downloaded by this script (the
# images are pre-downloaded into <dataset_dir>/raw_data); the value is only
# referenced by _clean_up_temporary_files().
_DATA_URL = 'http://url/to/download/xxx.tgz'
# Number of validation images: the cat-vs-dog train folder holds 25000 images,
# 30% of which are used for validation.
# NOTE(review): run() recomputes this count locally as 30% of the images it
# finds, so this module-level value is informational only.
_NUM_VALIDATION = 7500
# Seed for repeatability.
_RANDOM_SEED = 0
# Number of TFRecord files (shards) each split is written into.
_NUM_SHARDS = 5
class ImageReader(object):
    """Decodes JPEG-encoded bytes into RGB image arrays with a small TF graph.

    The decode op is built once, in the graph that is the default when the
    reader is constructed, and is then fed one encoded image at a time.
    """

    def __init__(self):
        # String placeholder fed with raw JPEG bytes, decoded to 3 channels.
        encoded_bytes = tf.placeholder(dtype=tf.string)
        self._decode_jpeg_data = encoded_bytes
        self._decode_jpeg = tf.image.decode_jpeg(encoded_bytes, channels=3)

    def read_image_dims(self, sess, image_data):
        """Returns the first two dimensions (height, width) of the decoded image."""
        pixels = self.decode_jpeg(sess, image_data)
        height, width = pixels.shape[0], pixels.shape[1]
        return height, width

    def decode_jpeg(self, sess, image_data):
        """Runs the decode op in `sess` on `image_data` and returns the result."""
        pixels = sess.run(self._decode_jpeg,
                          feed_dict={self._decode_jpeg_data: image_data})
        # Sanity checks: the op was built to produce exactly 3 channels.
        assert len(pixels.shape) == 3
        assert pixels.shape[2] == 3
        return pixels
def _get_filenames_and_classes(dataset_dir):
    """Returns the image paths under `<dataset_dir>/raw_data` and the class names.

    Every regular file directly inside `<dataset_dir>/raw_data` is returned;
    subdirectories are skipped. The class of an image is not determined here:
    the conversion step derives it from the file-name prefix (`cat.` /
    `dog.`), so the class list is fixed.

    Args:
      dataset_dir: The dataset directory; the raw images must be in its
        `raw_data` subdirectory.

    Returns:
      A tuple `(photo_filenames, class_names)`. `photo_filenames` holds the
      image paths joined onto `<dataset_dir>/raw_data`, sorted so the result
      does not depend on the file system's listing order. `class_names` is
      `['cat', 'dog']`.
    """
    cvd_root = os.path.join(dataset_dir, 'raw_data')
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # seeded shuffle done by the caller reproducible across machines.
    candidates = [os.path.join(cvd_root, name)
                  for name in sorted(os.listdir(cvd_root))]
    photo_filenames = [path for path in candidates if os.path.isfile(path)]
    class_names = ['cat', 'dog']
    return photo_filenames, class_names
def _get_dataset_filename(dataset_dir, split_name, shard_id):
    """Builds the path of one TFRecord shard file.

    Args:
      dataset_dir: Directory the TFRecord files are written to.
      split_name: 'train' or 'validation'.
      shard_id: Zero-based shard index.

    Returns:
      `dataset_dir` joined with a name of the form
      'cat_vs_dog_<split>_<shard, 5 digits>-of-<_NUM_SHARDS, 5 digits>.tfrecord'.
    """
    shard_basename = 'cat_vs_dog_%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, _NUM_SHARDS)
    shard_path = os.path.join(dataset_dir, shard_basename)
    return shard_path
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    """Converts the given image files to a sharded TFRecord dataset.

    `filenames` is cut into `_NUM_SHARDS` contiguous chunks of at most
    ceil(len(filenames) / _NUM_SHARDS) images; chunk k is written to
    `_get_dataset_filename(dataset_dir, split_name, k)`. The label of each
    image is the part of its base name before the first '.'
    (e.g. 'cat.123.jpg' -> 'cat'), looked up in `class_names_to_ids`.

    Args:
      split_name: The name of the dataset split, either 'train' or 'validation'.
      filenames: A list of paths to JPEG images.
      class_names_to_ids: A dictionary from class names (strings) to ids
        (integers).
      dataset_dir: The directory where the converted datasets are stored.

    Raises:
      KeyError: if a file-name prefix is not a key of `class_names_to_ids`.
    """
    assert split_name in ['train', 'validation']

    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()

        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id)

                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i+1, len(filenames), shard_id))
                        sys.stdout.flush()

                        # Read the encoded bytes; the `with` closes the file
                        # handle instead of relying on garbage collection.
                        with tf.gfile.FastGFile(filenames[i], 'rb') as image_file:
                            image_data = image_file.read()
                        height, width = image_reader.read_image_dims(sess, image_data)

                        # The class is the file-name prefix, e.g. 'cat.0.jpg' -> 'cat'.
                        fname = os.path.basename(filenames[i])
                        class_name = fname.split('.')[0]
                        class_id = class_names_to_ids[class_name]
                        print('\r>> File %s of class_name: %s and class_id: %d will be write into tfrecord.'%(filenames[i], class_name, class_id))

                        # The original encoded bytes (not the decoded pixels)
                        # are stored; decoding above only supplied the size.
                        example = dataset_utils.image_to_tfexample(
                            image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())

    sys.stdout.write('\n')
    sys.stdout.flush()
def _clean_up_temporary_files(dataset_dir):
    """Removes temporary download files from `dataset_dir`, if present.

    Deletes the archive named by the last path component of `_DATA_URL` and
    the `flower_photos` extraction directory. Each deletion is skipped when
    the target does not exist, so the call is safe after a conversion that
    started from pre-downloaded images.

    NOTE(review): 'flower_photos' is inherited from the flowers script; update
    it if an archive download step is ever enabled for this dataset. Never
    point it at `raw_data`, which holds the source images.

    Args:
      dataset_dir: The directory where the temporary files are stored.
    """
    filename = _DATA_URL.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)
    if tf.gfile.Exists(filepath):
        tf.gfile.Remove(filepath)

    tmp_dir = os.path.join(dataset_dir, 'flower_photos')
    if tf.gfile.Exists(tmp_dir):
        tf.gfile.DeleteRecursively(tmp_dir)
def _dataset_exists(dataset_dir):
    """Returns True only when every train and validation shard file exists.

    Files are checked split by split, shard by shard, stopping at the first
    missing one.
    """
    return all(
        tf.gfile.Exists(_get_dataset_filename(dataset_dir, split, shard))
        for split in ('train', 'validation')
        for shard in range(_NUM_SHARDS))
def run(dataset_dir):
    """Converts the images in `<dataset_dir>/raw_data` to TFRecord files.

    The images are shuffled with a fixed seed and split 70/30 into train and
    validation sets. The sets are written as sharded TFRecord files directly
    into `dataset_dir`, together with a label file mapping class ids to names.
    Nothing is done if all shard files already exist.

    Args:
      dataset_dir: The dataset directory; raw images are read from its
        `raw_data` subdirectory and the TFRecord files are written into it.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    if _dataset_exists(dataset_dir):
        print('Dataset files already exist. Exiting without re-creating them.')
        return

    # No download step: the images are expected to be in <dataset_dir>/raw_data
    # already, so the flowers script's download/cleanup calls are not made.
    photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
    class_names_to_ids = dict(zip(class_names, range(len(class_names))))

    # Divide into train and validation. Sorting first makes the split depend
    # only on the seed, not on directory-listing order; a private Random
    # instance produces the same permutation as random.seed() + random.shuffle()
    # without reseeding the global generator.
    photo_filenames = sorted(photo_filenames)
    random.Random(_RANDOM_SEED).shuffle(photo_filenames)
    # 30% of the images go to validation. A lowercase local is used so the
    # module-level _NUM_VALIDATION constant is not shadowed.
    num_validation = int(0.3 * len(photo_filenames))
    training_filenames = photo_filenames[num_validation:]
    validation_filenames = photo_filenames[:num_validation]

    # First, convert the training and validation sets.
    _convert_dataset('train', training_filenames, class_names_to_ids,
                     dataset_dir)
    _convert_dataset('validation', validation_filenames, class_names_to_ids,
                     dataset_dir)

    # Finally, write the labels file:
    labels_to_class_names = dict(zip(range(len(class_names)), class_names))
    dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

    print('\nFinished converting the cat_vs_dog dataset!')
if __name__ == "__main__":
    # The dataset directory may be passed as the first command-line argument.
    # It defaults to 'cvd', which is resolved against the current working
    # directory.
    run(sys.argv[1] if len(sys.argv) > 1 else 'cvd')
在slim/datasets目录下运行download_and_convert_cvd_v1_0.py(需确保slim目录已加入PYTHONPATH,以便脚本能导入datasets.dataset_utils)。运行后将在slim/datasets/cvd目录下生成5个cat_vs_dog_train_*.tfrecord文件、5个cat_vs_dog_validation_*.tfrecord文件,以及一个标签文件。
注意,在上一步数据集制作过程中,
_convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir)
该函数调用了一个函数:
example = dataset_utils.image_to_tfexample(
image_data, b'jpg', height, width, class_id)
tfrecord_writer.write(example.SerializeToString())
该函数的定义为:
def image_to_tfexample(image_data, image_format, height, width, class_id):
    """Packs one encoded image and its metadata into a `tf.train.Example`.

    The feature keys used here are the ones a reader must parse when decoding
    the resulting TFRecord files.

    Args:
      image_data: The encoded image bytes.
      image_format: The encoding name as bytes (the cvd conversion script
        passes b'jpg').
      height: Image height, as returned by ImageReader.read_image_dims.
      width: Image width, as returned by ImageReader.read_image_dims.
      class_id: Integer class label.

    Returns:
      A `tf.train.Example` holding the five features below.
    """
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }))
该函数明确了tfrecord中每一条tfexample记录的键值签名。后续在数据读取过程中,也应当按照此键值进行解码。
TODO:读取tfrecord,解码tfexample,并重新显示图片及其类别,以验证tfrecord文件的正确性。
在slim/datasets/目录下新建cvd_v1_0.py(可从flowers.py修改而来)。代码如下:
"""Provides data for the cat_vs_dog (cvd) dataset.

The dataset script used to create the dataset can be found at:
tensorflow/models/research/slim/datasets/download_and_convert_cvd_v1_0.py
(derived from download_and_convert_flowers.py).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from datasets import dataset_utils
slim = tf.contrib.slim
# Shard files written by download_and_convert_cvd_v1_0.py, e.g.
# 'cat_vs_dog_train_00000-of-00005.tfrecord'; '%s' receives the split name.
_FILE_PATTERN = 'cat_vs_dog_%s_*.tfrecord'

# Examples per split: 25000 training images split 70/30.
# NOTE(review): must match the split actually produced by the conversion script.
SPLITS_TO_SIZES = {'train': 17500, 'validation': 7500}

_NUM_CLASSES = 2

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    # Two classes: 0 ('cat') and 1 ('dog').
    'label': 'A single integer between 0 and 1',
}
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Gets a dataset tuple with instructions for reading the cat_vs_dog data.

    Args:
      split_name: A train/validation split name.
      dataset_dir: The base directory of the dataset sources.
      file_pattern: The file pattern to use when matching the dataset sources.
        It is assumed that the pattern contains a '%s' string so that the split
        name can be inserted. Falls back to `_FILE_PATTERN` when None or empty.
      reader: The TensorFlow reader type. Defaults to `tf.TFRecordReader`.

    Returns:
      A `Dataset` namedtuple.

    Raises:
      ValueError: if `split_name` is not a valid train/validation split.
    """
    if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

    # Honor a caller-supplied pattern; use the default only when none is given.
    if not file_pattern:
        file_pattern = _FILE_PATTERN
    file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

    # Allowing None in the signature so that dataset_factory can use the default.
    if reader is None:
        reader = tf.TFRecordReader

    # Keys must match those written by dataset_utils.image_to_tfexample.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        # Default applies only to records lacking the field; the conversion
        # script always writes b'jpg'.
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    labels_to_names = None
    if dataset_utils.has_labels(dataset_dir):
        labels_to_names = dataset_utils.read_label_file(dataset_dir)

    return slim.dataset.Dataset(
        data_sources=file_pattern,
        reader=reader,
        decoder=decoder,
        num_samples=SPLITS_TO_SIZES[split_name],
        items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
        num_classes=_NUM_CLASSES,
        labels_to_names=labels_to_names)
该文件主要实现了一个函数get_split,它返回一个slim.dataset.Dataset对象,供dataset_factory调用。
def get_split(split_name, dataset_dir, file_pattern=None, reader=None)
主要修改dataset_factory.py文件:导入cvd_v1_0模块,并在datasets_map中添加条目'cvd': cvd_v1_0。这样dataset_name='cvd'就会映射到cvd_v1_0.get_split(),由它生成相应的slim.dataset.Dataset对象。
"""A factory-pattern class which returns classification image/label pairs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import cvd_v1_0
# Maps a dataset name (e.g. 'cvd') to the module whose get_split() is called
# by get_dataset() below. 'cvd' is the cat_vs_dog dataset defined in
# cvd_v1_0.py.
datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'cvd': cvd_v1_0,
}
def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
    """Looks up dataset `name` in `datasets_map` and returns one of its splits.

    Args:
      name: String, the name of the dataset (a key of `datasets_map`).
      split_name: A train/test split name.
      dataset_dir: The directory where the dataset files are stored.
      file_pattern: The file pattern to use for matching the dataset source files.
      reader: The subclass of tf.ReaderBase. If left as `None`, then the default
        reader defined by each dataset is used.

    Returns:
      Whatever the dataset module's `get_split` returns (a slim `Dataset`).

    Raises:
      ValueError: If the dataset `name` is unknown.
    """
    # Every registered value is a module, so None reliably means "unknown".
    dataset_module = datasets_map.get(name)
    if dataset_module is None:
        raise ValueError('Name of dataset unknown %s' % name)
    return dataset_module.get_split(split_name, dataset_dir, file_pattern, reader)
该文件主要实现了函数:
def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None)
实现从dataset_name和split_name到slim.dataset.Dataset()的映射。至此,数据集准备完毕。
主要修改