[AI Tutorial] Getting Started with TensorFlow: Image Classification with TF-slim Models

Table of Contents

  • 1. Introduction
  • 2. File Overview
  • 3. Training Process
    • 3.1 Data Preprocessing
    • 3.2 Training
    • 3.3 Testing

1. Introduction

slim is a lightweight library for TensorFlow. Built on top of TensorFlow, it provides high-level abstractions that organize concepts such as networks, losses, and regularization in a coherent way, instead of the raw low-level TensorFlow style where hyperparameters, network definitions, and training loops are scattered everywhere.
For example, defining a convolution layer:

with tf.name_scope('conv_a') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 32, 64], dtype=tf.float32,
                                           stddev=1e-1), name='weights')
  conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
  biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
                       trainable=True, name='biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name=scope)

As you can see, this involves several moving parts: a name scope, a weight variable, a bias variable, the convolution op, and an activation function, so modifying any of it is quite tedious. With slim, the same convolution takes a single line:

net = slim.conv2d(input, 64, [5, 5], scope='conv_a')

Beyond eliminating boilerplate and letting users define models more compactly, slim also ships implementations of several common computer vision models (AlexNet, VGGNet, GoogLeNet, ResNet). Ordinary users can call them as black boxes, while users with research needs can modify and extend the implementations in various ways, saving the time of building models from scratch.
If you want to learn slim in depth, this blog post gives a detailed walkthrough: TensorFlow - TF-Slim 封装模块
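In slim.conv2d(input, 64, [5, 5], ...), 64 is the number of output channels and [5, 5] the kernel size; the input channel count (the 32 in the raw version's kernel shape [5, 5, 32, 64]) is inferred from the input tensor. With slim's defaults of stride 1 and 'SAME' padding the spatial size is preserved. As a quick sketch (plain Python, not part of the article's code), TensorFlow's output-size arithmetic is:

```python
import math

def conv_out_size(in_size, kernel, stride, padding):
    """Output spatial size of a 2D convolution, following TensorFlow's rules."""
    if padding == 'SAME':
        # SAME pads so that output size depends only on the stride
        return math.ceil(in_size / stride)
    elif padding == 'VALID':
        # VALID uses no padding; the kernel must fit entirely inside the input
        return math.ceil((in_size - kernel + 1) / stride)
    raise ValueError(padding)

print(conv_out_size(224, 5, 1, 'SAME'))   # 224: SAME keeps the size at stride 1
print(conv_out_size(224, 5, 2, 'SAME'))   # 112
```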

2. File Overview

This article provides the following files:

  • create_tfrecord.py — helper functions for reading and writing tfrecords files
  • train_model.py — trains the model
  • predict_test.py — tests the trained model
  • slim — a copy of TF-Slim
  • test_image — holds the test images
  • dataset — holds the dataset, structured as follows:
train contains the training set and val the validation set; the labels are stored in the corresponding txt files.
File download: Slim model classification

3. Training Process

3.1 Data Preprocessing

Training in TensorFlow is essentially the process of data flowing through the network. The official documentation describes three ways to read data:

  1. Feeding: supply data directly from Python
  2. Reading from files: read data from files
  3. Preloaded data: store the data in the graph itself as constants or variables

For large datasets, the official recommendation is the second option, using the standard TensorFlow format, stored in files with the tfrecords suffix. The provided create_tfrecord.py contains several key functions that can be used as-is for typical image classification problems.
This article uses VGG16 as its example. VGG16 expects 224x224 inputs; set the parameters in create_tfrecord.py and run it to generate train224.tfrecords and val224.tfrecords directly.

if __name__ == '__main__':
    # parameter settings
    resize_height = 224  # stored image height
    resize_width = 224   # stored image width
    shuffle = True
    log = 5  # logging interval
    # generate the train record file
    image_dir = 'dataset/train'
    train_labels = 'dataset/train.txt'  # label file
    train_record_output = 'dataset/record/train224.tfrecords'
    create_records(image_dir, train_labels, train_record_output, resize_height, resize_width, shuffle, log)
    train_nums = get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

    # generate the val record file
    image_dir = 'dataset/val'
    val_labels = 'dataset/val.txt'  # label file
    val_record_output = 'dataset/record/val224.tfrecords'
    create_records(image_dir, val_labels, val_record_output, resize_height, resize_width, shuffle, log)
    val_nums = get_example_nums(val_record_output)
    print("save val example nums={}".format(val_nums))

3.2 Training

The previous step packaged the data into tfrecords files. What follows is a detailed walkthrough of train_model.py:

First, define the key parameters at the top of the file

import tensorflow as tf
import numpy as np
import pdb
import os
from datetime import datetime
import slim.nets.vgg as vgg
from create_tfrecord import *
import tensorflow.contrib.slim as slim

#################################
labels_nums = 5  # number of classes
batch_size = 16
resize_height = 224  # stored image height
resize_width = 224   # stored image width
depths = 3
data_shape = [batch_size, resize_height, resize_width, depths]
train_record_file = 'dataset/record/train224.tfrecords'
val_record_file = 'dataset/record/val224.tfrecords'

train_log_step = 100
base_lr = 0.05  # learning rate
max_steps = 10000  # number of iterations
train_param = [base_lr, max_steps]

val_log_step = 200
snapshot = 2000  # checkpoint saving interval
snapshot_prefix = 'models/model.ckpt'
###################################

A TensorFlow program has two phases: defining the computation graph and executing it. Much of the data is supplied only at execution time, which is where placeholders come in. tf.placeholder() creates a placeholder node: like a function's formal parameter, it declares a value that is filled in later via the feed_dict of a run call. keep_prob is the dropout keep probability. When is_training=True, we feed training data for the training pass; when is_training=False, we feed val data for the validation pass.

# input_images holds the image data
input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths], name='input')
# input_labels holds the one-hot label data
input_labels = tf.placeholder(dtype=tf.int32, shape=[None, labels_nums], name='label')

keep_prob = tf.placeholder(tf.float32, name='keep_prob')
is_training = tf.placeholder(tf.bool, name='is_training')
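input_labels has shape [None, labels_nums] because the labels are fed in one-hot encoded (get_batch_images is called with one_hot=True later on). As a standalone illustration, not tied to the article's code, one-hot encoding with labels_nums = 5 looks like:

```python
import numpy as np

def one_hot(labels, num_classes):
    """Map integer class indices to one-hot row vectors."""
    out = np.zeros((len(labels), num_classes), dtype=np.int32)
    out[np.arange(len(labels)), labels] = 1  # set the column for each class index
    return out

# e.g. classes 2 and 0 out of 5
print(one_hot([2, 0], 5))
```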

Next, define the training function

def train(train_record_file,
          train_log_step,
          train_param,
          val_record_file,
          val_log_step,
          labels_nums,
          data_shape,
          snapshot,
          snapshot_prefix):
    '''
    :param train_record_file: tfrecord file for training
    :param train_log_step: interval for logging training info
    :param train_param: training parameters [base_lr, max_steps]
    :param val_record_file: tfrecord file for validation
    :param val_log_step: interval for logging validation info
    :param labels_nums: number of classes
    :param data_shape: input data shape
    :param snapshot: checkpoint saving interval
    :param snapshot_prefix: prefix of saved checkpoint files
    :return:
    '''

create_tfrecord.py provides read_records() for reading tfrecords files and get_batch_images() for fetching batches of data.

train_images, train_labels = read_records(train_record_file, resize_height, resize_width, type='normalization')
train_images_batch, train_labels_batch = get_batch_images(train_images, train_labels,
                                                              batch_size=batch_size, labels_nums=labels_nums,
                                                              one_hot=True, shuffle=True)
# validation data; the validation set does not need shuffling
val_images, val_labels = read_records(val_record_file, resize_height, resize_width, type='normalization')
val_images_batch, val_labels_batch = get_batch_images(val_images, val_labels,
                                                          batch_size=batch_size, labels_nums=labels_nums,
                                                          one_hot=True, shuffle=False)

The bundled models are all written with tf.contrib.slim. slim.arg_scope supplies default argument values to the TensorFlow layer functions inside its scope, so the model's defaults must be set up before defining it. To train a different model, simply replace import slim.nets.vgg as vgg as follows:

import slim.nets.vgg as vgg
# to train GoogLeNet (Inception) v1, use
import slim.nets.inception_v1 as inception_v1
# to train ResNet v1, use
import slim.nets.resnet_v1 as resnet_v1
# to train AlexNet, use
import slim.nets.alexnet as alexnet

In addition, each model has its own entry function and default arguments, so training a different model also requires corresponding changes. The bundled model files live under slim/nets; look up each model's entry function and defaults there, then adapt the following code.

with slim.arg_scope(vgg.vgg_arg_scope()):
    out, end_points = vgg.vgg_16(inputs=input_images, num_classes=labels_nums, dropout_keep_prob=keep_prob, is_training=is_training)

Next, define the loss, the evaluation metric, and the optimizer

tf.losses.softmax_cross_entropy(onehot_labels=input_labels, logits=out)  # cross-entropy loss

loss = tf.losses.get_total_loss(add_regularization_losses=True)  # total loss, including regularization losses
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(input_labels, 1)), tf.float32))
# Specify the optimization scheme:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=base_lr)
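The accuracy line compares the index of the largest logit against the index of the 1 in the one-hot label and averages the matches over the batch. The same computation in plain numpy, for intuition:

```python
import numpy as np

def batch_accuracy(logits, onehot_labels):
    """Fraction of samples whose argmax prediction matches the one-hot label."""
    pred = np.argmax(logits, axis=1)          # predicted class per sample
    truth = np.argmax(onehot_labels, axis=1)  # true class per sample
    return float(np.mean(pred == truth))

logits = [[0.1, 2.0, 0.3],   # predicts class 1
          [1.5, 0.2, 0.1]]   # predicts class 0
labels = [[0, 1, 0],         # true class 1 -> correct
          [0, 0, 1]]         # true class 2 -> wrong
print(batch_accuracy(logits, labels))  # 1 of 2 correct -> 0.5
```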

At this point a complete computation graph has been defined. The code below executes it

train_op = slim.learning.create_train_op(total_loss=loss,optimizer=optimizer)
saver = tf.train.Saver()
max_acc=0.0
with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   sess.run(tf.local_variables_initializer())
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(sess=sess, coord=coord)
   for i in range(max_steps+1):
       batch_input_images, batch_input_labels = sess.run([train_images_batch, train_labels_batch])
       _, train_loss = sess.run([train_op, loss], feed_dict={input_images:batch_input_images,
                                                                      input_labels:batch_input_labels,
                                                                      keep_prob:0.5, is_training:True})
       # evaluate on the training set (only the current batch)
       if i%train_log_step == 0:
          train_acc = sess.run(accuracy, feed_dict={input_images:batch_input_images,
                                                          input_labels: batch_input_labels,
                                                          keep_prob:1.0, is_training: False})
          print ("%s: Step [%d]  train Loss : %f, training accuracy :  %g" % (datetime.now(), i, train_loss, train_acc))

       # evaluate on the validation set (the entire val data)
       if i%val_log_step == 0:
          mean_loss, mean_acc=net_evaluation(sess, loss, accuracy, val_images_batch, val_labels_batch,val_nums)
          print ("%s: Step [%d]  val Loss : %f, val accuracy :  %g" % (datetime.now(), i, mean_loss, mean_acc))

       # save a checkpoint every `snapshot` steps and at the last step
       if (i % snapshot == 0 and i > 0) or i == max_steps:
          print('-----save:{}-{}'.format(snapshot_prefix, i))
          saver.save(sess, snapshot_prefix, global_step=i)
       # keep the checkpoint with the highest val accuracy
       if mean_acc > max_acc and mean_acc > 0.5:
          max_acc = mean_acc
          path = os.path.dirname(snapshot_prefix)
          best_models = os.path.join(path, 'best_models_{}_{:.4f}.ckpt'.format(i, max_acc))
          print('------save:{}'.format(best_models))
          saver.save(sess, best_models)
   coord.request_stop()
   coord.join(threads)

Start training and inspect the log output:

2018-11-14 08:09:36.191508: Step [0]  train Loss : 2.460573, training accuracy :  0.125
2018-11-14 08:10:17.346742: Step [0]  val Loss : 4.058232, val accuracy :  0.200269
2018-11-14 08:14:46.037606: Step [100]  train Loss : 1.979604, training accuracy :  0.1875
2018-11-14 08:19:14.251063: Step [200]  train Loss : 1.918397, training accuracy :  0.1875
2018-11-14 08:19:55.207292: Step [200]  val Loss : 870.869812, val accuracy :  0.198925
2018-11-14 08:24:19.959178: Step [300]  train Loss : 1.228397, training accuracy :  0.3125
2018-11-14 08:28:42.778443: Step [400]  train Loss : 1.846802, training accuracy :  0.25
2018-11-14 08:29:24.461384: Step [400]  val Loss : 268.398010, val accuracy :  0.200269
2018-11-14 08:33:37.340099: Step [500]  train Loss : 1.635928, training accuracy :  0.375
2018-11-14 08:37:49.819849: Step [600]  train Loss : 1.517327, training accuracy :  0.25
2018-11-14 08:38:28.789897: Step [600]  val Loss : 214.456604, val accuracy :  0.200941
2018-11-14 08:42:42.430475: Step [700]  train Loss : 1.086831, training accuracy :  0.375
2018-11-14 08:46:57.851332: Step [800]  train Loss : 1.707227, training accuracy :  0.3125
2018-11-14 08:47:37.021354: Step [800]  val Loss : 65.510918, val accuracy :  0.203629
2018-11-14 08:51:49.277937: Step [900]  train Loss : 1.176614, training accuracy :  0.25
2018-11-14 08:56:04.265438: Step [1000]  train Loss : 1.306993, training accuracy :  0.4375
2018-11-14 08:56:43.866303: Step [1000]  val Loss : 40.841892, val accuracy :  0.229167
·······
2018-11-14 11:21:56.745533: Step [4300]  train Loss : 0.709980, training accuracy :  0.625
2018-11-14 11:25:52.476962: Step [4400]  train Loss : 0.706520, training accuracy :  0.625
2018-11-14 11:26:30.562026: Step [4400]  val Loss : 2.057459, val accuracy :  0.605511
2018-11-14 11:31:49.112498: Step [4500]  train Loss : 0.725172, training accuracy :  0.625
2018-11-14 11:38:41.109449: Step [4600]  train Loss : 0.711620, training accuracy :  1
2018-11-14 11:39:44.186370: Step [4600]  val Loss : 1.891007, val accuracy :  0.655914
2018-11-14 11:46:32.496127: Step [4700]  train Loss : 0.812459, training accuracy :  0.8125
2018-11-14 11:53:19.712461: Step [4800]  train Loss : 0.728329, training accuracy :  0.5625
2018-11-14 11:54:24.584681: Step [4800]  val Loss : 1.948687, val accuracy :  0.629704
2018-11-14 12:01:11.850038: Step [4900]  train Loss : 0.706246, training accuracy :  0.625
2018-11-14 12:07:54.438784: Step [5000]  train Loss : 0.716349, training accuracy :  0.6875
2018-11-14 12:08:59.608909: Step [5000]  val Loss : 1.946240, val accuracy :  0.641801
2018-11-14 12:15:40.652675: Step [5100]  train Loss : 0.714592, training accuracy :  0.6875
2018-11-14 12:22:25.825326: Step [5200]  train Loss : 0.711778, training accuracy :  0.75
2018-11-14 12:23:29.359956: Step [5200]  val Loss : 2.166813, val accuracy :  0.660618
2018-11-14 12:30:16.579020: Step [5300]  train Loss : 0.711300, training accuracy :  0.625
2018-11-14 12:37:02.258133: Step [5400]  train Loss : 0.702036, training accuracy :  0.75
2018-11-14 12:38:05.865874: Step [5400]  val Loss : 2.106956, val accuracy :  0.646505

·······
2018-11-14 15:28:55.589414: Step [9000]  train Loss : 0.701351, training accuracy :  1
2018-11-14 15:29:33.319283: Step [9000]  val Loss : 1.496502, val accuracy :  0.816532
2018-11-14 15:33:30.661300: Step [9100]  train Loss : 0.699581, training accuracy :  0.8125
2018-11-14 15:37:27.817035: Step [9200]  train Loss : 0.837011, training accuracy :  0.9375
2018-11-14 15:38:06.510063: Step [9200]  val Loss : 1.485197, val accuracy :  0.822581
2018-11-14 15:42:03.880014: Step [9300]  train Loss : 0.698662, training accuracy :  1
2018-11-14 15:46:00.088570: Step [9400]  train Loss : 0.705768, training accuracy :  1
2018-11-14 15:46:38.134584: Step [9400]  val Loss : 1.292768, val accuracy :  0.848118
------save:models_V3/best_models_9400_0.8481.ckpt
2018-11-14 15:50:35.086152: Step [9500]  train Loss : 0.701549, training accuracy :  1
2018-11-14 15:54:31.618631: Step [9600]  train Loss : 0.701810, training accuracy :  1
2018-11-14 15:55:09.838977: Step [9600]  val Loss : 1.415559, val accuracy :  0.822581
2018-11-14 15:59:05.713706: Step [9700]  train Loss : 0.723139, training accuracy :  1
2018-11-14 16:03:02.386047: Step [9800]  train Loss : 0.700476, training accuracy :  1
2018-11-14 16:03:39.967466: Step [9800]  val Loss : 1.525427, val accuracy :  0.829973
2018-11-14 16:07:36.202739: Step [9900]  train Loss : 0.699305, training accuracy :  1
2018-11-14 16:11:30.578002: Step [10000]  train Loss : 0.698061, training accuracy :  1
2018-11-14 16:12:08.364591: Step [10000]  val Loss : 1.500178, val accuracy :  0.8125
-----save:models_V3/model.ckpt-10000

The log shows that both training and validation accuracy start out very low and climb steadily as training proceeds. Although the dataset is fairly small, validation accuracy eventually stabilizes above 80%.
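A side note on reading these numbers: with batch_size = 16, the per-batch training accuracy can only take values k/16, which is exactly why the log shows 0.125, 0.1875, 0.3125, and so on. A quick check:

```python
batch_size = 16
# every reported train accuracy is some count of correct samples divided by 16
possible = [k / batch_size for k in range(batch_size + 1)]
print(possible[:4])  # [0.0, 0.0625, 0.125, 0.1875]
```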

3.3 Testing

predict_test.py is straightforward, so its code is not listed here. Find a few images online, put them in the test_image folder, and run predict_test.py.
Images in test_image:
Prediction results:

test_image/1.jpg is: pre labels:[2],name:['animal'] score: [ 0.99848539]
test_image/2.jpg is: pre labels:[0],name:['flower'] score: [ 0.9996984]
test_image/3.jpg is: pre labels:[1],name:['guitar'] score: [ 0.99989831]
test_image/4.jpg is: pre labels:[3],name:['houses'] score: [ 0.99999392]
test_image/5.jpg is: pre labels:[2],name:['animal'] score: [ 0.76743358]
test_image/6.jpg is: pre labels:[4],name:['plane'] score: [ 0.7605322]
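Each output line reports the argmax class and its softmax probability. From the results above, the class-index mapping is 0: flower, 1: guitar, 2: animal, 3: houses, 4: plane. As an illustration (the logit values here are made up), top-1 prediction from a 5-class logit vector works like this:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / e.sum()

def top1(logits, names):
    """Return (index, class name, probability) of the highest-scoring class."""
    probs = softmax(np.asarray(logits, dtype=np.float64))
    idx = int(np.argmax(probs))
    return idx, names[idx], float(probs[idx])

names = ['flower', 'guitar', 'animal', 'houses', 'plane']
print(top1([0.2, 0.1, 6.0, 0.5, 0.3], names))  # class 2 ('animal') dominates
```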

Article content edited by: 郑杜磊
