本章主要使用已有模型微调做二次训练。训练样本有8421张男性头像和11599张女性头像,测试样本由10张男性头像,10张女性头像构成。我们使用NASNet预训练模型微调一个二分类模型来识别图像性别。
文件结构如下:
- mydataset.py 处理样本数据集的代码
- model.py 加载预训练模型,并微调的模型代码
- train.py 训练代码
- test.py 测试代码
处理数据的mydataset.py在上一篇【系列学习】数据通用处理 —— tf工程化项目实战里已经分析过了,这一篇主要分析model.py和训练测试的代码。
因为代码量比较多,本篇主要目的在于描述清楚我们需要在model.py 和 训练测试时做哪些事情,所以不会贴很多完整代码,以建立起结构的认识为首要目的~
【model.py】
model文件主要是定义一些训练的式子,也就是后面提到的op,以便在train.py和test.py时可以直接通过sess.run无感知细节的调用。
通常model文件我们会定义一个和模型同名的类,比如我们这边可以定义一个MyNASNetModel的类。
这个类结构如下:
调用关系如下:
build_model是暴露给调用方的主函数,参数支持里基本包含了训练需要的所有参数,训练和测试数据的地址,以及模式mode的控制:
- mode='train'
- testdata_dir='./data/val'
- traindata_dir='./data/train'
- batch_size=32
- learning_rate1=0.001
- learning_rate2=0.001
代码支持3种模式:train/test/val,所以build_model的主框架就是一个3分支的if..else,每个分支处理的流程大同小异,可以总结为下:
- 先用tf.reset_default_graph()清理下图形堆栈,否则scope里的图形还是会保存在内存里的
- 根据参数里的训练/测试样本地址生成Dataset对象,并生成迭代器,把初始化Dataset的op存储为self.train_init_op/self.test_init_op留待调用方在session里run
- 接下来会根据模式分成不同的处理:
- 如果是train就调用build_model_train方法
- 否则直接调用MyNASNet得到预测的op,再直接传入global_step调用加载函数的函数load_cpk
- 如果是train加一步self.global_init = tf.global_variables_initializer();如果是test加一步self.build_acc_base(labels)
- tf.get_default_graph().finalize()
可以看到1,2,4,5都是一些通用处理,主要差异在于第3步。第3步如果只是预测,就只需要拿到预测的op,如果是训练,就要调用训练相关的函数处理更多。
我们先分析一下训练和预测都用到的一些通用函数,比如:load_cpk 和 MyNASNet.
【load_cpk — 加载模型函数】
加载模型根据begin传参的不同分为两种处理模式:
1.begin为0或者不传时是模型的初始化:
save_path = r'./train_nasnet'
if not os.path.exists(save_path):
print("no model path")
saver = tf.train.Saver(max_to_keep=1)
return saver, save_path
创建了一个存储模型的地址save_path,和保存加载模型的构造器saver。
2.begin不为0时是模型的加载:
kpt = tf.train.latest_checkpoint(save_path)
print("load model", kpt)
startepo = 0
if kpt!=None:
saver.restore(sess, kpt)
ind = kpt.find("-")
startepo = int(kpt[ind+1:])
print("global_step=", global_step.eval(),startepo)
return startepo
会从存储模型的地址去读取最新训练的模型,恢复参数,并且返回训练步数。
加载模型里比较重要的需要加深印象的API有3个:
- tf.train.Saver(max_to_keep=1)
- tf.train.latest_checkpoint(save_path)
- saver.restore(sess, kpt)
【构建预测op — MyNASNet】
arg_scope = nasnet.nasnet_mobile_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = nasnet.build_nasnet_mobile(images, num_classes=self.num_classes+1, is_training=is_training)
global_step = tf.train.get_or_create_global_step()
return logits, end_points, global_step
这个就是一个常规的对slim的对应模型api的调用得到logits, end_points之类,顺便在这边初始化了一下global_step
然后在分析下build_acc_base,其主要目的是构建预测时acc的op。
【构建预测acc的op — build_acc_base】
self.prediction = tf.cast(tf.argmax(self.logits, 1), tf.int32)
self.correct_prediction = tf.equal(self.prediction, labels)
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
self.accuracy_top_5 = tf.reduce_mean(tf.cast(tf.nn.in_top_k(predictions=self.logits, targets=labels, k=5), tf.float32))
比较简单,也都是通用操作。
这样test/val模式就分析完了,接下来分析下训练的专门的函数:build_model_train和FineTuneNASNet。
【FineTuneNASNet】
model_path = self.model_path
exclude = ['final_layer','aux_7']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
if is_training == True:
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)
else:
init_fn = None
tuning_variables = []
for v in exclude:
tuning_variables += slim.get_variables(v)
return init_fn, tuning_variables
这一步里我们拿到最后两层进行微调,最后两层可以通过执行tf.global_variables()把结点打出来找到最后两层的变量名,存为exclude。来通过slim.assign_from_checkpoint_fn构造加载模型的函数init_fn,传参exclude去除掉对应不需要的结点变量;把exclude里的结点通过slim.get_variables存到tuning_variables里,留待后用。
【build_model_train】
先和验证/测试中的步骤一样,先拿到网络预测的op,根据上面的代码分析,我们只需调用MyNASNet拿到logits即可,同时可以得到endpoints和global_step。
然后调用FineTuneNASNet得到微调相关的初始函数self.init_fn和微调变量self.tuning_variables。接下来就是常规操作:
- 定义loss函数
tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=self.logits)
loss = tf.losses.get_total_loss()
- 定义学习率(优化相关参数)
learning_rate1 = tf.train.exponential_decay(learning_rate=learning_rate1, global_step=self.global_step, decay_steps=100,decay_rate=0.5)
learning_rate2 = tf.train.exponential_decay(learning_rate=learning_rate2, global_step=self.global_step, decay_steps=100,decay_rate=0.2)
3.构建使用优化器对loss op进行优化的op
last_optimizer = tf.train.AdamOptimizer(learning_rate1)
full_optimizer = tf.train.AdamOptimizer(learning_rate2)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
self.last_train_op = last_optimizer.minimize(loss, self.global_step, var_list=self.tuning_variables)
self.full_train_op = full_optimizer.minimize(loss, self.global_step)
其中tf.control_dependencies()函数是用来控制计算流图的,也就是给图中的某些计算指定顺序。关于tf.GraphKeys.UPDATE_OPS,这是一个tensorflow的计算图中内置的一个集合,其中会保存一些需要在训练操作之前完成的操作,并配tf.control_dependencies函数使用。这样的用法是为了确保update_op在train_op之前执行,保证bn时moving_mean, moving_var更新到对应的值。(但是其他代码训练时好像没有显示的加上这句,一个困惑:什么时候需要单独加上的呢?)
4.调用构建acc的op :build_acc_base
5.设置summary相关
model.py到这里就分析完了,基本上就是做了一些准备工作,定义一些op,下面我们看训练和测试的代码。
【test.py】
- 建立model里定义模型的实例,然后调用其build_model方法。
mymode = MyNASNetModel()
mymode.build_model('test', test_dir)
2.运行tf.Session(),以执行对应的op
- 加载已存在的模型
- 调用对应的check_accuracy / check_sex 跑结果
with tf.Session() as sess:
mymode.load_cpk(mymode.global_step, sess, 1, mymode.saver, mymode.save_path)
val_acc = check_accuracy(sess)
print('Val acc: %f\n' % val_acc)
image_dir = 'https://gimg2.baidu.com/image_search/src=http%3A%2F%2Finews.gtimg.com%2Fnewsapp_bt%2F0%2F11580927704%2F641.jpg&refer=http%3A%2F%2Finews.gtimg.com&app=2002&size=f9999,10000&q=a80&n=0&g=0n&fmt=jpeg?sec=1627097985&t=e1b51b567591ce7a193f1c13a4e4a504'
check_sex(image_dir, sess)
check_accuracy
1.sess.run([mymode.correct_prediction, mymode.accuracy, mymode.logits])
根据预测的正确数量计算准确率即可
check_sex
- 准备图片
img = Image.open(imgdir)
if "RGB" != img.mode:
img = img.convert("RGB")
img = np.asarray(img.resize((image_size, image_size)), dtype=np.float32).reshape(1, image_size, image_size, 3)
img = 2 * (img / 255.0) - 1.0
- sess.run logits的op,把图片喂给它
prediction = sess.run(mymode.logits, {mymode.images: img})
pre = prediction.argmax()
【train.py】
- 建立model里定义模型的实例,然后调用其build_model方法。
mymode = MyNASNetModel('./nasnet-a_mobile_04_10_2017/model.ckpt')
mymode.build_model('train', val_dir, train_dir, batch_size, learning_rate1, learning_rate2)
2.运行tf.Session(),以执行对应的op
- 加载已存在的模型
下面开始进行两次训练,一次微调,一次全部训练。
4.微调:如果步数为0则进入微调模式,执行model里定义的init_fn方法,加载预训练模型,去掉最后两层参数。写个for循环迭代num_epochs1(自定义变量)次,每次循环初始化Dataset对象,sess.run(mymode.train_init_op);循环里面再写个while True循环反复读取dataset数据,报错即一次全部数据迭代完成,存储为新模型。
for epoch in range(num_epochs1):
print('Starting1 epoch %d / %d' % (epoch+1, num_epochs1))
sess.run(mymode.train_init_op)
while True:
try:
step += 1
acc, accuracy_top_5, summary, _ = sess.run([mymode.accuracy, mymode.accuracy_top_5, mymode.merged, mymode.last_train_op])
if step % 100 == 0:
print(f'step: {step} train1 accuracy: {acc}, {accuracy_top_5}')
except tf.errors.OutOfRangeError:
print("train1:", epoch, "ok")
mymode.saver.save(sess, mymode.save_path+"/mynasnet.cpkt", global_step=mymode.global_step.eval())
break
sess.run(mymode.step_init)
print(">>>")
- 全局调试类似,不过迭代次数比微调要多,代码里微调epoch为20,全局调试为200,sess.run的op除了mymode.last_train_op改成了mymode.full_train_op,其他不变。