本篇更多的是在代码实战方向,不会涉及太多的理论。本文主要针对TensorFlow和卷积神经网络有一定基础的同学,并对图像处理有一定的了解。
阅读本文你大概需要以下知识:
1.TensorFlow基础
2.TensorFlow实现卷积神经网络的前向传播过程
3.TFRecord数据格式
4.Dataset的使用
5.Slim的使用
好了废话不多说,下面开始。
一.数据准备
首先我们需要有一个让我们训练的数据集,这里谷歌已经帮我们做好了。这里要把数据集下载下来,打开命令行,执行如下命令:
wget http://download.tensorflow.org/example_image/flower_photo.tgz
//解压
tar xzf flower_photos.tgz
这里需要注意的是,文件最好是下载到你的工程目录下方便你的读取。什么?你还不会搭建TensorFlow程序?请移步https://www.tensorflow.org/install/
选择自己的操作系统,在这里我的是macOS。我使用的是Virtualenv来搭建TensorFlow运行环境。
数据集下载并解压后,我们可以看到大概是这个样子
好了,数据有了?接下来该怎么办呢?当然是把数据进行预处理拉,你不会觉得我们的TensorFlow可以直接识别这些图片进行训练吧,hhhhhh。
二.数据预处理
接下来我们在目录下新建pre_data.python文件。TensorFlow对图片做处理一般是生成TFRecord文件。什么是TFRecord?后面我们会讲到。
首先我们要引入我们需要的库。
# glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list)
import glob
#os.path生成路径方便glob获取
import os.path
#这里主要用到随机数
import numpy as np
#引入tensorflow框架
import tensorflow as tf
#引入gflie对图片做处理
from tensorflow.python.platform import gfile
相关库在我们这个程序中的功能都作了简单介绍,下面用到的时候我们会更加详细的说明。
大家都知道我们的数据集一般分训练,测试和验证数据集。观察上面的数据集,谷歌只是给出了每一种花的图片,并没有给去哪些我训练,哪些是测试,哪些是验证数据集。所以在这里我们要进行划分。
#输入图片地址
INPUT_DATA = '../../flower_photos'
#训练数据集
OUTPUT_FILE = './path/to/output.tfrecords'
#测试数据集
OUTPUT_TEST_FILE = './path/to/output_test.tfrecords'
#验证数据集
OUTPUT_VALIDATION_FILE = './path/to/output_validation.tfrecords'
#测试数据和验证数据的比例
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10
关于VALIDATION_PERCENTAGE和TEST_PERCENTAGE这两个常量,我们在后面的例子会给出。
下面我们就来定义处理数据的方法:
def create_image_lists(sess,testing_percentage,validation_percentage):
#拿到INPUT_DATA文件夹下的所有目录(包括root)
sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
#如果是root_dir不需要做处理
is_root_dir = True
#定义图片对应的标签,从0-4分别代表不同的花
current_label = 0
#写入TFRecord的数据需要首先定义writer
#这里定义三个writer分别存储训练,测试和验证数据
writer = tf.python_io.TFRecordWriter(OUTPUT_FILE)
writer_test = tf.python_io.TFRecordWriter(OUTPUT_TEST_FILE)
writer_validation = tf.python_io.TFRecordWriter(OUTPUT_VALIDATION_FILE)
#循环目录
for sub_dir in sub_dirs:
if is_root_dir:
#跳过根目录
is_root_dir = False
continue
#定义空数组来装图片路径
file_list = []
#生成查找路径
dir_name = os.path.basename(sub_dir)
file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + "jpg")
# extend合并两个数组
# glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list)
# 比如:glob.glob(r’c:*.txt’) 这里就是获得C盘下的所有txt文件
file_list.extend(glob.glob(file_glob))
#路径下没有文件就跳过,不继续操作
if not file_list: continue
#这里我定义index来打印当前进度
index = 0
#file_list此时是图片路径列表
for file_name in file_list:
#使用gfile从路径中读取图片
image_raw_data = gfile.FastGFile(file_name, 'rb').read()
#对图像解码,解码结果为一个张量
image = tf.image.decode_jpeg(image_raw_data)
#对图像矩阵进行归一化处理
#因为为了将图片数据能够保存到 TFRecord 结构体中
#所以需要将其图片矩阵转换成 string
#所以为了在使用时能够转换回来
#这里确定下数据格式为 tf.float32
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# 将图片转化成299*299方便模型处理
image = tf.image.resize_images(image, [299, 299])
#为了拿到图片的真实数据这里我们要运行一个session op
image_value = sess.run(image)
pixels = image_value.shape[1]
#存储在TFrecord里面的不能是array的形式
#所以我们需要利用tostring()将上面的矩阵
#转化成字符串
#再通过tf.train.BytesList转化成可以存储的形式
image_raw = image_value.tostring()
#存到features
#随机划分测试集和训练集
#这里存入TFRecord三个数据,图像的pixels像素
#图像原张量,这里我们需要转成string
#以及当前图像对应的标签
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(current_label),
'image_raw': _bytes_feature(image_raw)
}))
chance = np.random.randint(100)
#随机划分数据集
if chance < validation_percentage:
writer_validation.write(example.SerializeToString())
elif chance < (testing_percentage+validation_percentage):
writer_test.write(example.SerializeToString())
else:
writer.write(example.SerializeToString())
# print('example',index)
index = index + 1
#每一个文件夹下的所有图片都是一个类别
#所以这里每遍历完一个文件夹,标签就增加1
current_label += 1
writer.close()
writer_validation.close()
writer_test.close()
运行上述程序需要一定时间,我的电脑比较烂,大概跑了三十分钟左右。这时候在你的./path/to目录下可以看到output.tfrecords,output_test.tfrecords,output_validation.tfrecords三个文件,分别存放了训练,测试和验证数据集。上述代码将所有图片划分成训练、验证和测试数据集。并且把图片从原始的jpg格式转换成inception-v3模型需要的299 * 299 * 3的数字矩阵。在数据处理完毕之后,通过以下命令可以下载谷歌提供好的Inception_v3模型。
wget http://download.tensorflow.org/models/inception_v3_2016_08_26.tar.gz
//解压之后可以得到训练好的模型文件inception_v3.ckpt
tar xzf inception_v3_2016_08
二.训练
当新的数据集和已经训练好的模型都准备好之后,我们来写代码在谷歌inception_v3的基础上训练新数据集。
首先同样我们导入相关的库并且定义相关常量。在这里我们通过slim工具来直接加载模型,而不用自己再定义前向传播过程。
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# 加载通过TensorFlow-Silm定义好的 inception_v3模型
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
# 输入数据文件
INPUT_DATA = './path/to/output.tfrecords'
# 验证数据集
VALIDATION_DATA = './path/to/output_validation.tfrecords'
# 保存训练好的模型的路径
ls = './path/to/save_model'
# 谷歌提供的训练好的模型文件地址
CKPT_FILE = './path/to/inception_v3.ckpt'
TRAIN_FILE = './path/to/save_model'
# 定义训练中使用的参数
LEARNING_RATE = 0.01
#组合batch的大小
BATCH = 32
#用于one_hot函数输出概率分布
N_CLASSES = 5
#打乱顺序,并设置出队和入队中元素最少的个数,这里是10000个
shuffle_buffer = 10000
# 不需要从谷歌模型中加载的参数,这里就是最后的全连接层。因为输出类别不一样,所以最后全连接层的参数也不一样
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
# 需要训练的网络层参数 这里就是最后的全连接层
TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
接下来我们定义几个辅助方法。首先因为我们的数据存在TFRecord里,需要定义方法从TFRecord解析数据。
def parse(record):
features = tf.parse_single_example(
record,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'pixels': tf.FixedLenFeature([], tf.int64)
}
)
#decode_raw用于解析TFRecord里面的字符串
decoded_image = tf.decode_raw(features['image_raw'], tf.uint8)
label = features['label']
#要注意这里的decoded_image并不能直接进行reshape操作
#之前我们在存储的时候,把图片进行了tostring()操作
#这会导致图片的长度在原来基础上*8
#后面我们要用到numpy的fromstring来处理
return decoded_image, label
接下来定义两个方法。因为我们已经下载了谷歌训练好的inception_v3模型的参数,下面我们需要定义两个方法从里面加载参数。
#直接从inception_v3.ckpt中读取的参数
def get_tuned_variables():
#strip删除头尾字符,默认为空格
exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
variables_to_restore = []
#这里给出了所有slim模型下的参数
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
return variables_to_restore
#需要重新训练的参数
def get_trainable_variables():
#strip删除头尾字符,默认为空格
scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
variables_to_train = []
# 枚举所有需要训练的参数前缀,并通过这些前缀找到所有的参数。
for scope in scopes:
#从TRAINABLE_VARIABLES集合中获取名为scope的变量
#也就是我们需要重新训练的参数
variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope)
variables_to_train.extend(variables)
return variables_to_train
这里我们就写完了所需要的工具函数,接下来我们定义主函数。主函数主要完成数据读取,模型定义,通过模型得出前向传播结果,通过损失函数计算损失,最后把损失交给优化器做处理。首先我们先来完成数据读取的代码,这里我们使用的是TensorFlow高层API Dataset。不清楚的可以去看一下Dataset的用法。
这里我们在训练的同时也对模型做了验证。所以我们需要加载训练和验证数据
#读取测试数据
#利用TFRecordDataset读取TFRecord文件
dataset = tf.data.TFRecordDataset([INPUT_DATA])
#解析TFRecord
dataset = dataset.map(parse)
#把数据打乱顺序并组装成batch
dataset = dataset.shuffle(shuffle_buffer).batch(BATCH)
#定义数据重复的次数
NUM_EPOCHS = 10
dataset = dataset.repeat(NUM_EPOCHS)
#定义迭代器来获取处理后的数据
iterator = dataset.make_one_shot_iterator()
#迭代器开始迭代
img, label = iterator.get_next()
#读取验证数据(同上)
valida_dataset = tf.data.TFRecordDataset([VALIDATION_DATA])
valida_dataset = valida_dataset.map(parse)
valida_dataset = valida_dataset.batch(BATCH)
valida_iterator = valida_dataset.make_one_shot_iterator()
valida_img,valida_label = valida_iterator.get_next()
#定义inception-v3的输入,images为输入图片,label为每一张图片对应的标签
#再解释下每一个维度 None为batch的大小,299为图片大小,3为通道
images = tf.placeholder(tf.float32,[None,299,299,3],name='input_images')
labels = tf.placeholder(tf.int64,[None],name='labels')
要注意上述定义的只是tensorflow的张量,保存的只是计算过程并没有具体的数据。只有运行session之后才会拿到具体的数据。
下面我们来通过slim加载inception-v3模型
#定义inception-v3模型结构 inception_v3.ckpt里只有参数的取值
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
#logits inception_v3前向传播得到的结果
logits,_ = inception_v3.inception_v3(images,num_classes=N_CLASSES)
#获取需要训练的变量
trainable_variables = get_trainable_variables()
#这里用交叉熵作为损失函数,注意一下tf.losses.softmax_cross_entropy的参数
# tf.losses.softmax_cross_entropy(
# onehot_labels, # 注意此处参数名就叫 onehot_labels
# logits,
# weights=1.0,
# label_smoothing=0,
# scope=None,
# loss_collection=tf.GraphKeys.LOSSES,
# reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
# )
#这里要把labels转成one_hot类型,logits就是神经网络的输出
tf.losses.softmax_cross_entropy(tf.one_hot(labels,N_CLASSES),logits,weights=1.0)
#把计算的损失交给优化器处理
train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())
#计算正确率。
with tf.name_scope('evaluation'):
correct_prediction = tf.equal(tf.argmax(logits,1),labels)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#定义加载模型的函数
load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE,get_tuned_variables(),ignore_missing_vars=True)
#定义保存新的训练好的模型的函数
saver = tf.train.Saver()
with tf.Session() as sess:
#初始化所有变量
init = tf.global_variables_initializer()
sess.run(init)
print('Loading tuned variables from %s'%CKPT_FILE)
#加载谷歌已经训练好的模型
load_fn(sess)
step = 0;
#在这里我们用一个while来循环训练,直到dataset里没有数据就结束循环
while True:
try:
if step % 30 == 0 or step + 1 == STEPS:
#每30轮输出一次正确率
if step != 0:
#每30轮保存一次当前模型的参数,以便中途训练中断可以继续
saver.save(sess,TRAIN_FILE,global_step=step)
#运行session拿到真实图片的数据
valida_img_batch,valida_label_batch = sess.run([valida_img,valida_label])
#上面有提到TFRecord里图片数据被转成了string,在这里转回来
valida_img_batch = np.fromstring(valida_img_batch, dtype=np.float32)
#把图片张量拉成新的维度
valida_img_batch = tf.reshape(valida_img_batch, [32, 299, 299, 3])
#用session运行上述操作,得到处理后的图片张量
valida_img_batch = sess.run(valida_img_batch)
#把图片张量传到feed_dict算出正确率并显示
validation_accuracy = sess.run(evaluation_step,feed_dict={
images:valida_img_batch,
labels:valida_label_batch
})
print('Step %d: Validation accurary = %.1f%%'%(step,validation_accuracy*100.0))
#下面是对训练数据的操作,同上
img_batch,label_batch = sess.run([img,label])
img_batch = np.fromstring(img_batch, dtype=np.float32)
img_batch = tf.reshape(img_batch, [32,299, 299, 3])
img_batch = sess.run(img_batch)
sess.run(train_step,feed_dict={
images:img_batch,
labels:label_batch
})
#step仅仅用于记录
step = step + 1
except tf.errors.OutOfRangeError:
break
运行上述程序开始训练。在这里我暂时是使用cpu进行训练,训练过程大约3小时,可以得到类型下面的结果。
step 0:Validation accuracy = 12.5%
step 30:Validation accuracy = 22.2%
step 60:Validation accuracy = 63.2%
step 90:Validation accuracy = 79.8%
step 120:Validation accuracy = 86.4%
step 150:Validation accuracy = 88.5%
.....
以上就是我使用谷歌Inception-v3模型训练新的数据集的全部内容。