The workflow for fine-tuning a pretrained model is: build the forward pass, restore the backbone weights from the checkpoint while excluding the replaced classification layer, initialize the new layer, define the training op, and then run the training loop:
# Forward pass through the network.
with slim.arg_scope(mobilenet_v1_arg_scope(weight_decay=WEIGHT_DECAY)):
    # with tf.name_scope("net"):
    logits, _ = mobilenet_v1(images, num_classes=NUM_CLASS, is_training=is_training,
                             depth_multiplier=0.25,
                             dropout_keep_prob=KEEP_PROB)
# For MobileNet, depth_multiplier must be overridden to match the checkpoint;
# the multiplier value is encoded in the checkpoint's name.
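# For example, the released MobileNet v1 checkpoints are named like
# mobilenet_v1_0.25_128, i.e. trained with depth_multiplier=0.25 at 128x128
# input. A mismatched multiplier changes the channel counts, so the variable
# shapes no longer agree with the checkpoint and restoring fails.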
# Load the pretrained network.
# The variables to restore must be defined before the newly added graph ops,
# since get_variables_to_restore() collects whatever exists at call time.
# assert (os.path.isfile(MODEL_PATH))
variables_to_restore = slim.get_variables_to_restore(
    exclude=['MobilenetV1/Logits/Conv2d_1c_1x1'])  # do not restore the last layer
init_fn = slim.assign_from_checkpoint_fn(MODEL_PATH, variables_to_restore)
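# To double-check which scope to exclude, one can list the variables stored
# in the checkpoint before building the restore list (a sketch using the
# TF 1.x checkpoint-inspection helper tf.train.list_variables; older versions
# expose the same helper as tf.contrib.framework.list_variables):
# for name, shape in tf.train.list_variables(MODEL_PATH):
#     print(name, shape)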
# Initialize the variables that were not restored from the checkpoint.
logits_variables1 = slim.get_variables('MobilenetV1/Logits/Conv2d_1c_1x1')
logits_init1 = tf.variables_initializer(logits_variables1)
logits_init_list = [logits_init1]
global_step = slim.create_global_step()
with tf.name_scope("train_op"):
    # fc8_optimizer = tf.train.GradientDescentOptimizer(BASE_LR1)
    # (do not define too many optimizers: each one allocates its own slot
    #  variables and memory blows up)
    learning_rate = tf.train.exponential_decay(
        PARAMS.params['model']['baseLR'],
        global_step,
        train_num / BATCH_SIZE,
        PARAMS.params['model']['decayLR'])
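    # With decay_steps = train_num / BATCH_SIZE (one epoch's worth of steps)
    # and staircase left at its default of False, this schedule evaluates to
    #   lr = baseLR * decayLR ** (global_step / decay_steps),
    # i.e. the learning rate shrinks by a factor of decayLR per epoch.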
    # To freeze the layers before fc7, set var_list to the variables that
    # should be updated; when var_list is not given, minimize() updates all
    # trainable variables in the graph.
    # fc8_train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step, var_list=fc8_variables)
    # Train all layers.
    full_train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
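    # For this MobileNet, a head-only alternative might look like the sketch
    # below (tf.get_collection with a scope filter is standard TF 1.x; the
    # scope name matches the logits layer excluded from the restore above):
    # head_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
    #                               scope='MobilenetV1/Logits')
    # head_train_op = tf.train.AdamOptimizer(learning_rate).minimize(
    #     loss, global_step=global_step, var_list=head_vars)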
with tf.name_scope("accuracy"):
    prediction = tf.argmax(logits, 1)
    correct_prediction = tf.equal(prediction, tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar("accuracy", accuracy)
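    # Note: this is per-batch accuracy. An epoch-level running average could
    # use tf.metrics.accuracy instead, whose internal counters are local
    # variables (one reason tf.local_variables_initializer() is run below).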
init_op = [tf.global_variables_initializer(), tf.local_variables_initializer()]
summary_merge = tf.summary.merge_all()
# Finalize the current graph, making it read-only.
tf.get_default_graph().finalize()
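# finalize() makes the graph raise an error if any new op is created later,
# which catches the common leak of accidentally building ops (e.g. a fresh
# tf.argmax) inside the training loop.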
with tf.Session() as sess:
    # Initialization order matters: tf.global_variables_initializer() resets
    # every global variable, so it must run BEFORE the restore; running it
    # afterwards would wipe out the weights just loaded from the checkpoint.
    sess.run(init_op)            # initialize everything, incl. global_step
    sess.run(logits_init_list)   # explicitly (re-)initialize the new logits layer
    init_fn(sess)                # then load the pretrained backbone weights
    # tensorboard_writer = tf.summary.FileWriter(TENSORBOARD_PATH, sess.graph)
    for epoch in range(NUM_EPOCHS):
        print('Starting epoch %d / %d' % (epoch + 1, NUM_EPOCHS))
        sess.run(train_iterator.initializer)
        while True:
            try:
                train_batch_images, train_batch_labels = \
                    sess.run([train_images, train_labels])
                _, train_loss, train_acc = sess.run(
                    [full_train_op, loss, accuracy],
                    feed_dict={is_training: True,
                               images: train_batch_images,
                               labels: train_batch_labels})
                val_batch_images, val_batch_labels = \
                    sess.run([val_images, val_labels])
                val_loss, val_acc = sess.run(
                    [loss, accuracy],
                    feed_dict={is_training: False,
                               images: val_batch_images,
                               labels: val_batch_labels})
                step = sess.run(global_step)
                print("global_step:{0}".format(step))
                print("epoch:{0}, train loss:{1}, train-acc:{2}".format(epoch, train_loss, train_acc))
                print("epoch:{0}, val loss:{1}, val-acc:{2}".format(epoch, val_loss, val_acc))
            except tf.errors.OutOfRangeError:
                break
    # tensorboard_writer.close()
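The listing assumes that the input pipeline, the placeholders, and `loss` are defined earlier in the script. A minimal sketch of what those definitions might look like, with dummy in-memory data standing in for a real file-reading pipeline (the constants and the loss choice here are assumptions; only the identifier names match the code above):

import numpy as np
import tensorflow as tf

IMG_SIZE = 128      # hypothetical input resolution
NUM_CLASS = 5       # hypothetical class count
BATCH_SIZE = 32     # hypothetical batch size

# Placeholders fed by the training loop.
is_training = tf.placeholder(tf.bool, name='is_training')
images = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 3], name='images')
labels = tf.placeholder(tf.float32, [None, NUM_CLASS], name='labels')

# Dummy one-hot data; a real pipeline would decode image files here instead.
dummy_images = np.zeros((100, IMG_SIZE, IMG_SIZE, 3), np.float32)
dummy_labels = np.eye(NUM_CLASS, dtype=np.float32)[
    np.random.randint(NUM_CLASS, size=100)]

train_dataset = (tf.data.Dataset.from_tensor_slices((dummy_images, dummy_labels))
                 .shuffle(100)
                 .batch(BATCH_SIZE))
# An initializable iterator matches sess.run(train_iterator.initializer) in the
# loop above and raises tf.errors.OutOfRangeError at the end of each epoch.
train_iterator = train_dataset.make_initializable_iterator()
train_images, train_labels = train_iterator.get_next()

# The validation pair (val_images, val_labels) would be built the same way,
# with .repeat() so it never runs out mid-epoch.

# `loss` could be cross-entropy plus the weight-decay regularizers that
# mobilenet_v1_arg_scope registered through tf.losses:
# cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
# loss = tf.losses.get_total_loss()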