TensorFlow Keras 使用Inception-resnet-v2模型训练自己的分类数据集(含源码)

TensorFlow Keras 使用Inception-resnet-v2模型训练自己的分类数据集(含源码)

运行环境

TensorFlow 1.13.1
TensorFlow.Keras 2.2.4-tf

简单介绍

使用TensorFlow自带的Inception-resnet-v2模型训练自己的数据集。数据读取用的是TensorFlow自己的Dataset类,且无需转存成TFrecord格式。使用TensorFlow中的Keras,简单易懂,容易上手

注意事项

  • 先准备好一个分类数据集
  • 使用GPU训练(用CPU应该训练不动Inception-resnet-v2模型,如果没有GPU你可以换成TensorFlow现有的其他模型,但代码需要进行一定的改动)

源码

废话不多说,上源码。
DATA_PATH放数据集路径

ds = ds.prefetch(buffer_size=10*BATCH_SIZE)

这一句用于预读取数据,用的时候注意下CPU和内存,特别是内存,如果百分之九十多了就把程序关了(友情提示Ctrl+C可关闭程序),把这句话注释掉再跑

mport pathlib
import random
import tensorflow as tf

# 训练数据集路径
DATA_PATH = 'E:\dataset'
# 一个批次的大小
BATCH_SIZE = 16
# 启用动态图
tf.enable_eager_execution()
# 数据集路径规范化
data_root = pathlib.Path(DATA_PATH)
# 获取图片的绝对路径
all_image_paths = list(data_root.glob('*/*'))
# 转换成字符串存进列表中
all_image_paths = [str(path) for path in all_image_paths]
# 图片数量
image_count = len(all_image_paths)
# 打乱顺序
random.shuffle(all_image_paths)
# 获取所有类名
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
# 创建类名、索引字典
label_to_index = dict((name, index) for index,name in enumerate(label_names))
# 获取所有的图片类名对应的索引
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]for path in all_image_paths]
# 进行one-hot编码
all_image_labels = tf.one_hot(all_image_labels,170,1,0)

# 图片的预处理函数
def preprocess_image(image):
  image = tf.image.decode_jpeg(image, channels=3)
  image = tf.image.resize(image, [299, 299])
  image /= 255.0  # normalize to [0,1] range
  return image

# 图片的加载函数
def load_and_preprocess_image(path):
  image = tf.io.read_file(path)
  return preprocess_image(image)

# 创建图片的Dataset类用于传输图片数据
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
# 绑定图片的初始化操作(加载和预处理)
image_ds = path_ds.map(load_and_preprocess_image)
# 创建索引的Dataset类
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
# 将两个Dataset类打包成一个
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
# 设置数据集无限循环使用
ds = image_label_ds.repeat()
# 打乱Dataset中的元素
ds = ds.shuffle(buffer_size=20*BATCH_SIZE)
# 设置Dataset的批次大小
ds = ds.batch(BATCH_SIZE)
# 预读取数据,减少数据读取对训练的影响
ds = ds.prefetch(buffer_size=10*BATCH_SIZE)
# 从Keras中获取现有的网络模型
model = tf.keras.applications.InceptionResNetV2(weights=None,classes=170)
# 网络模型的参数设置
model.compile(
  optimizer=tf.keras.optimizers.RMSprop(),
  loss='binary_crossentropy',
  metrics=['accuracy'])
# 设置训练时的回调函数
class CollectBatchStats(tf.keras.callbacks.Callback):
  def __init__(self):
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    self.model.reset_metrics()

batch_stats_callback = CollectBatchStats()
# 开始训练
history = model.fit(ds, epochs=100,
                    steps_per_epoch=30,
                    callbacks = [batch_stats_callback])

参考文章:https://tensorflow.google.cn/beta/tutorials/load_data/images

你可能感兴趣的:(深度学习随笔,TensorFlow,TensorFlow,Keras,深度学习)