目录
简介:
1、数据集制作
2、slim修改及训练
训练:
3、模型导出
使用官方bazel模型导出:
使用tensorflow模块功能导出
本文将记录分类样本如何制作为tfrecord格式,以及如何用tensorflow的slim模块训练分类模型,并把模型固化导出。
环境准备:
样本准备:
我这里准备了天干样本,建立了甲、乙、丙、丁、戊、己、庚、辛、壬、癸10个类别,(类别只为演示用),样本图片为.jpg后缀图片,图片名称及大小无限制, 需要根据类别存储到对应文件夹内。
样本分类如下图:
在 <安装目录>\models-master\research\slim\datasets\ 文件夹内,建立一个新的convert_mydataset.py文件,文件全部内容如下
#coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import random
import sys
import tensorflow as tf
from datasets import dataset_utils
# The URL where the Flowers data can be downloaded.
# (Left over from the flowers converter; not referenced by the code visible in this file.)
_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
# The number of images in the validation set.
# (Not referenced by the code visible in this file; the split is driven by
# the `frate` argument of make_train_val_label_txt instead.)
_NUM_VALIDATION = 350
# Seed for repeatability.
_RANDOM_SEED = 0
# The number of shards (output .tfrecord files) per dataset split.
_NUM_SHARDS = 5
# File names generated inside the dataset directory, in this fixed order:
# [0] training list, [1] validation list, [2] class-name list.
subname = ['train.txt', 'validation.txt', 'labels.txt']
class ImageReader(object):
    """Decodes JPEG byte strings through a small, reusable TensorFlow graph."""

    def __init__(self):
        # Build the decode op once; raw image bytes are fed through the
        # placeholder on every call, and the output is forced to 3 channels.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(
            self._decode_jpeg_data, channels=3)

    def decode_jpeg(self, sess, image_data):
        """Runs the decode op on `image_data` and returns the decoded array.

        Args:
          sess: Session used to run the decode op.
          image_data: Encoded JPEG bytes.

        Returns:
          The decoded image; asserted to be rank 3 with 3 channels.
        """
        feed = {self._decode_jpeg_data: image_data}
        decoded = sess.run(self._decode_jpeg, feed_dict=feed)
        assert len(decoded.shape) == 3
        assert decoded.shape[2] == 3
        return decoded

    def read_image_dims(self, sess, image_data):
        """Returns the first two dimensions (height, width) of the decoded image."""
        height, width = self.decode_jpeg(sess, image_data).shape[:2]
        return height, width
def _get_filenames_and_classes(dataset_dir):
"""Returns a list of filenames and inferred class names.
Args:
dataset_dir: A directory containing a set of subdirectories representing
class names. Each subdirectory should contain PNG or JPG encoded images.
Returns:
A list of image file paths, relative to `dataset_dir` and the list of
subdirectories, representing class names.
"""
#改为自己的数据集
flower_root = os.path.join(dataset_dir, 'fruit_photos')
directories = []
class_names = []
for filename in os.listdir(flower_root):
path = os.path.join(flower_root, filename)
if os.path.isdir(path):
directories.append(path)
class_names.append(filename)
photo_filenames = []
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename)
photo_filenames.append(path)
return photo_filenames, sorted(class_names)
def _get_dataset_filename(dataset_dir, split_name, shard_id):
    """Builds the path of one TFRecord shard inside `dataset_dir`.

    Args:
      dataset_dir: Directory the shard file lives in.
      split_name: Split label embedded in the file name.
      shard_id: Zero-based shard index; written zero-padded to 5 digits,
        followed by the total shard count `_NUM_SHARDS`.

    Returns:
      The joined path of the shard file.
    """
    # The file-name prefix identifies the dataset; adjust it for your own data.
    shard_info = (split_name, shard_id, _NUM_SHARDS)
    return os.path.join(
        dataset_dir, 'mydataset_%s_%05d-of-%05d.tfrecord' % shard_info)
def _convert_dataset(txtname, split_name, dataset_dir):
    """Converts the images listed in a label file into sharded TFRecord files.

    Args:
      txtname: Path of the list file. Each line holds an image path (joined
        onto `dataset_dir` when read) followed by its integer label.
      split_name: Name of the split, embedded in the output file names.
      dataset_dir: Directory the image paths are relative to, and where the
        `_NUM_SHARDS` output shards are written.
    """
    # Load the list file, reading only one label per image.
    images_list, labels_list = load_labels_file(txtname, 1)
    num_images = len(images_list)
    num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()
        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                # Output record file for this shard.
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, num_images)
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i + 1, num_images, shard_id))
                        sys.stdout.flush()

                        filename = os.path.join(dataset_dir, images_list[i])
                        print(filename)
                        # Read the raw bytes; the context manager closes the
                        # handle instead of leaking one per image.
                        with tf.gfile.FastGFile(filename, 'rb') as image_file:
                            image_data = image_file.read()
                        height, width = image_reader.read_image_dims(
                            sess, image_data)

                        # load_labels_file returns one list of labels per
                        # image, so this is a one-element list.
                        class_id = labels_list[i]
                        example = dataset_utils.image_to_tfexample(
                            image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())

    sys.stdout.write('\n')
    sys.stdout.flush()
def load_labels_file(filename, labels_num=1, shuffle=False):
    """Loads an image list file into parallel lists of paths and labels.

    Each non-blank line describes one image as whitespace-separated fields:
    the image path followed by its integer labels, for example
    `test_image/1.jpg 0 2`. Blank lines are skipped, and any run of spaces or
    tabs counts as one separator. Image paths therefore must not contain
    whitespace.

    Args:
      filename: Path of the list file.
      labels_num: Number of labels to read per line; extra fields are ignored.
      shuffle: Whether to shuffle the lines before parsing.

    Returns:
      A tuple `(images, labels)`: `images` is a list of path strings and
      `labels` is a list of `labels_num`-element integer lists.

    Raises:
      IndexError: If a non-blank line has fewer than `labels_num` label fields.
      ValueError: If a label field is not an integer.
    """
    with open(filename) as f:
        lines_list = f.readlines()
    if shuffle:
        random.shuffle(lines_list)

    images = []
    labels = []
    for raw_line in lines_list:
        # split() without arguments tolerates tabs, repeated spaces and
        # trailing newlines, and yields [] for a blank line.
        fields = raw_line.split()
        if not fields:
            continue
        label = [int(fields[i + 1]) for i in range(labels_num)]
        images.append(fields[0])
        labels.append(label)
    return images, labels
def make_train_val_label_txt(ori_dir, frate=0.8):
    """Writes train.txt, validation.txt and labels.txt for a class-per-folder dataset.

    Every sub-directory of `ori_dir` is one class; class ids are the indices
    of the sorted sub-directory names. The files of each class are shuffled
    with a generator seeded by `_RANDOM_SEED` (so repeated runs over the same
    files give the same split) and divided between the two list files.

    Output files (names come from `subname`, any existing copy is overwritten):
      * labels list: one class name per line, in class-id order.
      * train / validation lists: one `<class_name>/<file_name> <class_id>`
        line per image.

    Args:
      ori_dir: Dataset root directory containing one folder per class.
      frate: Fraction of each class assigned to the training list.
    """
    # NOTE(review): the labels file holds bare class names; slim's
    # dataset_utils.read_label_file appears to expect `id:name` lines, so
    # confirm before pointing a slim dataset reader at this directory.
    train_path, val_path, label_path = (
        os.path.join(ori_dir, name) for name in subname)

    # Find the class sub-folders; sorting fixes the class-id assignment.
    class_names = sorted(
        entry for entry in os.listdir(ori_dir)
        if os.path.isdir(os.path.join(ori_dir, entry)))
    print("class_names=\n", class_names)

    # A private generator keeps the split repeatable without reseeding the
    # global `random` state.
    rng = random.Random(_RANDOM_SEED)

    # Mode 'w' truncates files left by a previous run, so no explicit delete
    # is needed and each file is opened only once.
    with open(label_path, 'w') as label_file, \
            open(train_path, 'w') as train_file, \
            open(val_path, 'w') as val_file:
        for class_id, class_name in enumerate(class_names):
            label_file.write('%s\n' % (class_name))

            class_dir = os.path.join(ori_dir, class_name)
            # Only regular files can be images; sort before shuffling because
            # os.listdir order is arbitrary.
            filenames = sorted(
                entry for entry in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, entry)))
            rng.shuffle(filenames)

            # int() keeps the slice index an integer on Python 2 as well,
            # where round() returns a float.
            left = int(round(len(filenames) * frate + 0.5))
            for name in sorted(filenames[:left]):
                train_file.write('%s/%s %d\n' % (class_name, name, class_id))
            for name in sorted(filenames[left:]):
                val_file.write('%s/%s %d\n' % (class_name, name, class_id))
def run(dataset_dir):
    """Generates the split list files and converts both splits to TFRecord.

    Args:
      dataset_dir: The dataset directory where the dataset is stored; the
        list files and the converted shards are written here as well.
    """
    # Write the train / validation / label text files first.
    make_train_val_label_txt(dataset_dir)

    # Convert the training split, then the validation split, each from its
    # own list file.
    for split_name, list_name in (('train', subname[0]),
                                  ('validation', subname[1])):
        list_path = os.path.join(dataset_dir, list_name)
        _convert_dataset(list_path, split_name, dataset_dir)

    print('\nFinished converting the mydataset dataset!')
在slim文件夹下打开download_and_convert_data.py 文件,添加:
from datasets import convert_mydataset
再在def main(_): 函数中添加
elif FLAGS.dataset_name == 'mydataset':
convert_mydataset.run(FLAGS.dataset_dir)
然后在命令行执行:
python download_and_convert_data.py --dataset_name=mydataset --dataset_dir="E:\样本_天干"
没有报错的话,会在样本目录内生成系列.tfrecord文件和train.txt,validation.txt,和labels.txt文件。
在 datasets 文件夹内将 flowers.py 复制一份并命名为 mydataset.py,然后在 mydataset.py 中将 _FILE_PATTERN = 'flowers_%s_*.tfrecord' 改为: _FILE_PATTERN = 'mydataset_%s_*.tfrecord'。
将 SPLITS_TO_SIZES = {'train': 3320, 'validation': 350} 改为: SPLITS_TO_SIZES = {'train': 431, 'validation': 102},其中,train代表训练的图片张数,validation代表验证使用的图片张数,数量要和自己的数据集数量对应,使用本文方法可以通过查看第一步生成的train.txt,validation.txt里面的行数确定。
将 _NUM_CLASSES = 5 改为: _NUM_CLASSES =10 为实际样本类别数
在 datasets 文件夹内的 dataset_factory.py 文件中添加:
from datasets import mydataset
datasets_map 字典中添加:
'mydataset': mydataset,
由于训练时文件路径不能包含中文,把第1步生成的.tfrecord后缀的文件全部拷贝到 e:\HeavenlyStems目录里,训练命令如下:
python3 train_image_classifier.py \
--train_dir=e:\log \ #训练文件保存路径
--dataset_dir=e:\HeavenlyStems \ #样本存放路径
--dataset_name=mydataset \ #样本名称
--dataset_split_name=train \
--model_name="mobilenet_v2_140" \ #模型名称
--checkpoint_path=e:\mobilenet_v2_1.4_224\mobilenet_v2_1.4_224.ckpt \
--checkpoint_exclude_scopes=MobilenetV2/Logits,MobilenetV2/Predictions,MobilenetV2/predics \
--trainable_scopes=MobilenetV2/Logits,MobilenetV2/Predictions,MobilenetV2/predics \
--max_number_of_steps=20000 \ #迭代次数
--preprocessing_name="inception_v2" \
--learning_rate=0.045 \
--label_smoothing=0.1 \
--moving_average_decay=0.9999 \
--batch_size=32 \
--num_clones=1 \
--learning_rate_decay_factor=0.98 \
--num_epochs_per_decay=2.5
模型导出提供2种方法:
分两步
第一步: Exporting the Inference Graph
python export_inference_graph.py \
--alsologtostderr \
--dataset_dir=e:\HeavenlyStems \
--dataset_name=mydataset \
--model_name=mobilenet_v2_140 \
--image_size=224 \
--output_file=e:\log\mobilenet_v2_244.pb
第二步:Freezing the exported Graph
需要先下载tensorflow源码,并安装对应版本的bazel,在tensorflow源码文件夹内执行命令:
bazel build tensorflow/python/tools:freeze_graph
编译需要等待一段时间,编译成功后,在编译目录执行
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=e:\log\mobilenet_v2_244.pb \
--input_checkpoint=e:\log\model.ckpt-20000 \
--output_graph=e:\log\mobilenet_v2_1.0_224_frozen.pb \
--input_binary=True \
--output_node_names=MobilenetV2/Predictions/Reshape_1
执行如下命令:
python3 -m tensorflow.python.tools.freeze_graph \
--input_graph e:\log\graph.pbtxt \
--input_checkpoint e:\log\model.ckpt-40856 \
--input_binary false \
--output_graph e:\log\mobilenet_v2_frozen.pb \
--output_node_names MobilenetV2/Predictions/Reshape_1
参考:
tensorflow深度学习实战笔记(一):使用tensorflow slim自带的模型训练自己的数据