slim (TF-Slim) is a lightweight library built on top of TensorFlow. It provides high-level wrappers that organize concepts such as networks, losses, and regularization in a coherent way, instead of the low-level native TensorFlow style where hyperparameters, network definitions, and training loops are scattered everywhere.
For example, defining a convolution in native TensorFlow:
with tf.name_scope('conv_a') as scope:
    kernel = tf.Variable(tf.truncated_normal([5, 5, 32, 64], dtype=tf.float32,
                                             stddev=1e-1), name='weights')
    conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
    biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
                         trainable=True, name='biases')
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope)
As you can see, this single layer involves a name scope, a weight variable, a bias variable, an activation function, the convolution op itself, and more; modifying it would be quite tedious. With slim, the same convolution takes one line:
net = slim.conv2d(input, 64, [5, 5], scope='conv_a')
Besides eliminating boilerplate and letting you define models more compactly, slim also ships implementations of several common computer-vision models (AlexNet, VGGNet, GoogLeNet, ResNet). Ordinary users can call them as black boxes, while researchers can modify and extend them in various ways, saving the time of building models from scratch.
If you want to learn more about slim, this blog post has a detailed walkthrough: TensorFlow - TF-Slim 封装模块
create_tfrecord.py: interfaces for creating and reading tfrecords files
train_model.py: trains the model
predict_test.py: tests the model
slim: a copy of TF-Slim
test_image: holds the test images
dataset: holds the dataset, structured as follows:
train holds the training set and val the validation set; the labels are stored in the corresponding txt files.
File download: Slim模型分类
Training in TensorFlow is the process of data flowing through the network. The official documentation describes three ways of reading data: feeding via placeholders, reading from files, and preloading the data into the graph.
For large datasets, the official recommendation is the second option: storing the data in the standard TensorFlow format, with the file suffix tfrecords. The create_tfrecord.py provided with this article contains several key functions that can be used directly for typical image-classification problems.
This article uses VGG16 as the example. VGG16 expects 224x224 inputs; set the parameters in create_tfrecord.py and run it to produce train224.tfrecords and val224.tfrecords directly.
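The train.txt and val.txt label files consumed by create_tfrecord.py are plain text, one sample per line. Assuming the common "filename label" format (an assumption; verify against your own label files), a minimal stdlib sketch of parsing them:

```python
def load_labels_file(filename):
    """Return a list of (image_name, int_label) pairs from a label file.

    Assumes one sample per line in "relative/path/to/image label" format;
    this mirrors what create_tfrecord.py needs, not its actual code.
    """
    samples = []
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            name, label = line.rsplit(' ', 1)  # label is the last field
            samples.append((name, int(label)))
    return samples
```

create_records() would then iterate over these pairs, load and resize each image, and serialize it with its label into the tfrecords file.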
if __name__ == '__main__':
    # Parameter settings
    resize_height = 224  # stored image height
    resize_width = 224   # stored image width
    shuffle = True
    log = 5  # logging interval

    # Produce the train record file
    image_dir = 'dataset/train'
    train_labels = 'dataset/train.txt'  # label file for the training set
    train_record_output = 'dataset/record/train224.tfrecords'
    create_records(image_dir, train_labels, train_record_output, resize_height, resize_width, shuffle, log)
    train_nums = get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

    # Produce the val record file
    image_dir = 'dataset/val'
    val_labels = 'dataset/val.txt'  # label file for the validation set
    val_record_output = 'dataset/record/val224.tfrecords'
    create_records(image_dir, val_labels, val_record_output, resize_height, resize_width, shuffle, log)
    val_nums = get_example_nums(val_record_output)
    print("save val example nums={}".format(val_nums))
The previous step turned the data into tfrecords files. What follows is a detailed walkthrough of train_model.py.
First, define the key parameters at the top of the file:
import tensorflow as tf
import numpy as np
import pdb
import os
from datetime import datetime
import slim.nets.vgg as vgg
from create_tf_record import *
import tensorflow.contrib.slim as slim

#################################
labels_nums = 5        # number of classes
batch_size = 16
resize_height = 224    # stored image height
resize_width = 224     # stored image width
depths = 3
data_shape = [batch_size, resize_height, resize_width, depths]
train_record_file = 'dataset/record/train224.tfrecords'
val_record_file = 'dataset/record/val224.tfrecords'
train_log_step = 100
base_lr = 0.05         # learning rate
max_steps = 10000      # number of iterations
train_param = [base_lr, max_steps]
val_log_step = 200
snapshot = 2000        # checkpoint-saving interval
snapshot_prefix = 'models/model.ckpt'
###################################
A TensorFlow program has two phases: defining the computation graph and executing it. Much of the data is only supplied at execution time, which is where placeholders come in. tf.placeholder() creates a placeholder node: a node whose value is not fixed when the graph is built but is supplied by the user when calling run. You can think of it as a formal parameter, used to define the computation and bound to a concrete value at execution time. keep_prob is the dropout keep probability. With is_training=True we feed training data for the training pass; with is_training=False we feed validation data for the evaluation pass.
# input_images holds the image data
input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths], name='input')
# input_labels holds the label data
input_labels = tf.placeholder(dtype=tf.int32, shape=[None, labels_nums], name='label')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
is_training = tf.placeholder(tf.bool, name='is_training')
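A placeholder behaves like a named formal parameter of the graph: the graph is built once as a function of its placeholders, and sess.run binds actual values via feed_dict. A rough plain-Python analogy of this deferred binding (not TensorFlow code; the names are illustrative only):

```python
# The "graph" is defined once with free parameters (the placeholders);
# values are only bound when the computation is actually executed.
def build_graph():
    def run(feed_dict):
        x = feed_dict['input']            # analogous to input_images
        keep_prob = feed_dict['keep_prob']
        # stand-in computation; the real graph would run the network
        return [v * keep_prob for v in x]
    return run

run = build_graph()                                  # graph-definition phase
out = run({'input': [1.0, 2.0], 'keep_prob': 0.5})   # execution phase, values fed in
# out == [0.5, 1.0]
```

The same built "graph" can be re-executed with different feeds, which is exactly how one set of placeholders serves both the training pass (keep_prob=0.5) and the evaluation pass (keep_prob=1.0).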
Next, define the training function:
def train(train_record_file,
          train_log_step,
          train_param,
          val_record_file,
          val_log_step,
          labels_nums,
          data_shape,
          snapshot,
          snapshot_prefix):
    '''
    :param train_record_file: tfrecord file for training
    :param train_log_step: interval for logging training progress
    :param train_param: training parameters [base_lr, max_steps]
    :param val_record_file: tfrecord file for validation
    :param val_log_step: interval for logging validation progress
    :param labels_nums: number of labels
    :param data_shape: input data shape
    :param snapshot: checkpoint-saving interval
    :param snapshot_prefix: prefix for saved model files
    :return:
    '''
create_tfrecord.py provides read_records() for reading tfrecords files and get_batch_images() for fetching batches of data.
train_images, train_labels = read_records(train_record_file, resize_height, resize_width, type='normalization')
train_images_batch, train_labels_batch = get_batch_images(train_images, train_labels,
                                                          batch_size=batch_size, labels_nums=labels_nums,
                                                          one_hot=True, shuffle=True)
# val data; validation data does not need to be shuffled
val_images, val_labels = read_records(val_record_file, resize_height, resize_width, type='normalization')
val_images_batch, val_labels_batch = get_batch_images(val_images, val_labels,
                                                      batch_size=batch_size, labels_nums=labels_nums,
                                                      one_hot=True, shuffle=False)
All of the bundled models are written with tf.contrib.slim. slim.arg_scope supplies default argument values to TensorFlow layer functions, so before defining the model you need to set up its default parameters. To train a different model, simply replace import slim.nets.vgg as vgg, as follows:
import slim.nets.vgg as vgg
# to train GoogLeNet v1, change to
import slim.nets.inception_v1 as inception_v1
# to train ResNet v1, change to
import slim.nets.resnet_v1 as resnet_v1
# to train AlexNet, change to
import slim.nets.alexnet as alexnet
In addition, each model has its own entry function and default parameters, so training a different model requires corresponding changes. The bundled model files live under slim/nets; look up each model's entry function and defaults there, then adjust the following code accordingly.
with slim.arg_scope(vgg.vgg_arg_scope()):
    out, end_points = vgg.vgg_16(inputs=input_images, num_classes=labels_nums, dropout_keep_prob=keep_prob, is_training=is_training)
Next, define the loss function, the evaluation metric, and the optimizer:
tf.losses.softmax_cross_entropy(onehot_labels=input_labels, logits=out)  # cross-entropy
loss = tf.losses.get_total_loss(add_regularization_losses=True)  # include regularization losses
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(input_labels, 1)), tf.float32))
# Specify the optimization scheme:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=base_lr)
At this point a complete computation graph has been defined; the following code executes it.
train_op = slim.learning.create_train_op(total_loss=loss, optimizer=optimizer)
saver = tf.train.Saver()
max_acc = 0.0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(max_steps + 1):
        batch_input_images, batch_input_labels = sess.run([train_images_batch, train_labels_batch])
        _, train_loss = sess.run([train_op, loss], feed_dict={input_images: batch_input_images,
                                                              input_labels: batch_input_labels,
                                                              keep_prob: 0.5, is_training: True})
        # evaluate on the training set (just this one batch)
        if i % train_log_step == 0:
            train_acc = sess.run(accuracy, feed_dict={input_images: batch_input_images,
                                                      input_labels: batch_input_labels,
                                                      keep_prob: 1.0, is_training: False})
            print("%s: Step [%d] train Loss : %f, training accuracy : %g" % (datetime.now(), i, train_loss, train_acc))
        # evaluate on the validation set (all val data)
        if i % val_log_step == 0:
            mean_loss, mean_acc = net_evaluation(sess, loss, accuracy, val_images_batch, val_labels_batch, val_nums)
            print("%s: Step [%d] val Loss : %f, val accuracy : %g" % (datetime.now(), i, mean_loss, mean_acc))
            # save the model with the highest val accuracy
            # (kept inside this block so mean_acc is always defined when used)
            if mean_acc > max_acc and mean_acc > 0.5:
                max_acc = mean_acc
                path = os.path.dirname(snapshot_prefix)
                best_models = os.path.join(path, 'best_models_{}_{:.4f}.ckpt'.format(i, max_acc))
                print('------save:{}'.format(best_models))
                saver.save(sess, best_models)
        # save the model every `snapshot` iterations and at the last step
        if (i % snapshot == 0 and i > 0) or i == max_steps:
            print('-----save:{}-{}'.format(snapshot_prefix, i))
            saver.save(sess, snapshot_prefix, global_step=i)
    coord.request_stop()
    coord.join(threads)
Start training and watch the log:
2018-11-14 08:09:36.191508: Step [0] train Loss : 2.460573, training accuracy : 0.125
2018-11-14 08:10:17.346742: Step [0] val Loss : 4.058232, val accuracy : 0.200269
2018-11-14 08:14:46.037606: Step [100] train Loss : 1.979604, training accuracy : 0.1875
2018-11-14 08:19:14.251063: Step [200] train Loss : 1.918397, training accuracy : 0.1875
2018-11-14 08:19:55.207292: Step [200] val Loss : 870.869812, val accuracy : 0.198925
2018-11-14 08:24:19.959178: Step [300] train Loss : 1.228397, training accuracy : 0.3125
2018-11-14 08:28:42.778443: Step [400] train Loss : 1.846802, training accuracy : 0.25
2018-11-14 08:29:24.461384: Step [400] val Loss : 268.398010, val accuracy : 0.200269
2018-11-14 08:33:37.340099: Step [500] train Loss : 1.635928, training accuracy : 0.375
2018-11-14 08:37:49.819849: Step [600] train Loss : 1.517327, training accuracy : 0.25
2018-11-14 08:38:28.789897: Step [600] val Loss : 214.456604, val accuracy : 0.200941
2018-11-14 08:42:42.430475: Step [700] train Loss : 1.086831, training accuracy : 0.375
2018-11-14 08:46:57.851332: Step [800] train Loss : 1.707227, training accuracy : 0.3125
2018-11-14 08:47:37.021354: Step [800] val Loss : 65.510918, val accuracy : 0.203629
2018-11-14 08:51:49.277937: Step [900] train Loss : 1.176614, training accuracy : 0.25
2018-11-14 08:56:04.265438: Step [1000] train Loss : 1.306993, training accuracy : 0.4375
2018-11-14 08:56:43.866303: Step [1000] val Loss : 40.841892, val accuracy : 0.229167
·······
2018-11-14 11:21:56.745533: Step [4300] train Loss : 0.709980, training accuracy : 0.625
2018-11-14 11:25:52.476962: Step [4400] train Loss : 0.706520, training accuracy : 0.625
2018-11-14 11:26:30.562026: Step [4400] val Loss : 2.057459, val accuracy : 0.605511
2018-11-14 11:31:49.112498: Step [4500] train Loss : 0.725172, training accuracy : 0.625
2018-11-14 11:38:41.109449: Step [4600] train Loss : 0.711620, training accuracy : 1
2018-11-14 11:39:44.186370: Step [4600] val Loss : 1.891007, val accuracy : 0.655914
2018-11-14 11:46:32.496127: Step [4700] train Loss : 0.812459, training accuracy : 0.8125
2018-11-14 11:53:19.712461: Step [4800] train Loss : 0.728329, training accuracy : 0.5625
2018-11-14 11:54:24.584681: Step [4800] val Loss : 1.948687, val accuracy : 0.629704
2018-11-14 12:01:11.850038: Step [4900] train Loss : 0.706246, training accuracy : 0.625
2018-11-14 12:07:54.438784: Step [5000] train Loss : 0.716349, training accuracy : 0.6875
2018-11-14 12:08:59.608909: Step [5000] val Loss : 1.946240, val accuracy : 0.641801
2018-11-14 12:15:40.652675: Step [5100] train Loss : 0.714592, training accuracy : 0.6875
2018-11-14 12:22:25.825326: Step [5200] train Loss : 0.711778, training accuracy : 0.75
2018-11-14 12:23:29.359956: Step [5200] val Loss : 2.166813, val accuracy : 0.660618
2018-11-14 12:30:16.579020: Step [5300] train Loss : 0.711300, training accuracy : 0.625
2018-11-14 12:37:02.258133: Step [5400] train Loss : 0.702036, training accuracy : 0.75
2018-11-14 12:38:05.865874: Step [5400] val Loss : 2.106956, val accuracy : 0.646505
·······
2018-11-14 15:28:55.589414: Step [9000] train Loss : 0.701351, training accuracy : 1
2018-11-14 15:29:33.319283: Step [9000] val Loss : 1.496502, val accuracy : 0.816532
2018-11-14 15:33:30.661300: Step [9100] train Loss : 0.699581, training accuracy : 0.8125
2018-11-14 15:37:27.817035: Step [9200] train Loss : 0.837011, training accuracy : 0.9375
2018-11-14 15:38:06.510063: Step [9200] val Loss : 1.485197, val accuracy : 0.822581
2018-11-14 15:42:03.880014: Step [9300] train Loss : 0.698662, training accuracy : 1
2018-11-14 15:46:00.088570: Step [9400] train Loss : 0.705768, training accuracy : 1
2018-11-14 15:46:38.134584: Step [9400] val Loss : 1.292768, val accuracy : 0.848118
------save:models_V3/best_models_9400_0.8481.ckpt
2018-11-14 15:50:35.086152: Step [9500] train Loss : 0.701549, training accuracy : 1
2018-11-14 15:54:31.618631: Step [9600] train Loss : 0.701810, training accuracy : 1
2018-11-14 15:55:09.838977: Step [9600] val Loss : 1.415559, val accuracy : 0.822581
2018-11-14 15:59:05.713706: Step [9700] train Loss : 0.723139, training accuracy : 1
2018-11-14 16:03:02.386047: Step [9800] train Loss : 0.700476, training accuracy : 1
2018-11-14 16:03:39.967466: Step [9800] val Loss : 1.525427, val accuracy : 0.829973
2018-11-14 16:07:36.202739: Step [9900] train Loss : 0.699305, training accuracy : 1
2018-11-14 16:11:30.578002: Step [10000] train Loss : 0.698061, training accuracy : 1
2018-11-14 16:12:08.364591: Step [10000] val Loss : 1.500178, val accuracy : 0.8125
-----save:models_V3/model.ckpt-10000
The log shows that both training and validation accuracy start out very low and rise steadily as training proceeds. Although the dataset is fairly small, the val accuracy eventually stabilizes above 80%.
predict_test.py is straightforward, so its code is omitted here. Grab a few images from the web, put them in the test_image folder, and run predict_test.py.
Images in test_image:
Prediction results:
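To plot the val-accuracy curve, the log lines above can be parsed directly. A small stdlib sketch that matches the exact "val Loss : x, val accuracy : y" wording shown in the log:

```python
import re

# Matches the val lines printed by the training loop above.
VAL_RE = re.compile(r'Step \[(\d+)\] val Loss : ([\d.]+), val accuracy : ([\d.]+)')

def parse_val_log(lines):
    """Extract (step, val_loss, val_acc) tuples from training log lines."""
    results = []
    for line in lines:
        m = VAL_RE.search(line)
        if m:  # train-only lines and save lines are skipped
            results.append((int(m.group(1)), float(m.group(2)), float(m.group(3))))
    return results

log = ['2018-11-14 08:10:17: Step [0] val Loss : 4.058232, val accuracy : 0.200269',
       '2018-11-14 08:14:46: Step [100] train Loss : 1.979604, training accuracy : 0.1875']
print(parse_val_log(log))  # -> [(0, 4.058232, 0.200269)]
```

Feeding the parsed tuples to any plotting tool makes the early spikes in the val loss (e.g. the 870.87 at step 200) and the later plateau easy to see at a glance.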
test_image/1.jpg is: pre labels:[2],name:['animal'] score: [ 0.99848539]
test_image/2.jpg is: pre labels:[0],name:['flower'] score: [ 0.9996984]
test_image/3.jpg is: pre labels:[1],name:['guitar'] score: [ 0.99989831]
test_image/4.jpg is: pre labels:[3],name:['houses'] score: [ 0.99999392]
test_image/5.jpg is: pre labels:[2],name:['animal'] score: [ 0.76743358]
test_image/6.jpg is: pre labels:[4],name:['plane'] score: [ 0.7605322]
Content edited by: 郑杜磊