TensorFlow Learning (2): Fine-tuning a Pre-trained Network, Using MobileNetV1 as an Example

Table of Contents

      • 1. Workflow
      • 2. Fine-tuning MobileNetV1

1. Workflow

The workflow for fine-tuning a pre-trained network is:

  • Prepare the data (see the data-pipeline sketch after this list).
  • Define the network structure; besides the network itself, you can also define the optimizer, accuracy computation, and so on as needed.
  • Locate the network definition and choose which variables to restore (important: when fine-tuning, the final classification layer must be replaced if your dataset has a different number of classes).
  • Choose which variables gradient descent should update (i.e., you can freeze the earlier feature-extraction layers and train only the classification layer).
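The code below references images, labels, is_training, and train_iterator, which the full article defines elsewhere. A minimal sketch of what those definitions might look like, assuming one-hot labels and a tf.data pipeline over pre-loaded numpy arrays train_x / train_y (all names, shapes, and constants here are illustrative, not from the original):

import tensorflow as tf

IMG_SIZE = 224   # input resolution, matching the checkpoint name
NUM_CLASS = 5    # hypothetical number of classes in your dataset
BATCH_SIZE = 32  # hypothetical batch size

# Placeholders fed by the training loop.
images = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 3])
labels = tf.placeholder(tf.float32, [None, NUM_CLASS])  # one-hot labels
is_training = tf.placeholder(tf.bool)

# A re-initializable dataset; train_iterator.initializer is run once per epoch.
dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
dataset = dataset.shuffle(1000).batch(BATCH_SIZE)
train_iterator = dataset.make_initializable_iterator()
train_images, train_labels = train_iterator.get_next()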

2. Fine-tuning MobileNetV1

  • Download the pre-trained weights and extract them
    Network definition
    Weights
    Here I downloaded mobilenet_v1_0.25_224. The extracted contents are shown below; the checkpoint is named "mobilenet_v1_0.25_224.ckpt". The 0.25 refers to the value of depth_multiplier, which must be set when defining the network.
    [Figure 1: contents of the extracted mobilenet_v1_0.25_224 archive]
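Before wiring up the restore logic, it can help to list the variable names stored in the checkpoint, for example to confirm the scope of the final classification layer that has to be excluded later. A quick sketch using tf.train.NewCheckpointReader (the path is taken from the extracted archive above):

import tensorflow as tf

reader = tf.train.NewCheckpointReader('mobilenet_v1_0.25_224.ckpt')
# Print every variable name and shape in the checkpoint; the final
# classification layer appears under MobilenetV1/Logits/Conv2d_1c_1x1.
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)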
  • Define the forward network and set the variables to restore
# Forward pass of the network.
# Assumed imports: mobilenet_v1 lives in the slim models repo
# (tensorflow/models/research/slim/nets/mobilenet_v1.py).
import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets.mobilenet_v1 import mobilenet_v1, mobilenet_v1_arg_scope

with slim.arg_scope(mobilenet_v1_arg_scope(weight_decay=WEIGHT_DECAY)):
    logits, _ = mobilenet_v1(images, num_classes=NUM_CLASS, is_training=is_training,
                             depth_multiplier=0.25,
                             dropout_keep_prob=KEEP_PROB)
    # For MobileNet, depth_multiplier must match the checkpoint;
    # the value (0.25 here) is part of the model name.
# Restore the network.
# Collect the variables to restore before adding any new ops to the graph.
# assert os.path.isfile(MODEL_PATH)
variables_to_restore = slim.get_variables_to_restore(
    exclude=['MobilenetV1/Logits/Conv2d_1c_1x1'])  # do not restore the final classification layer
init_fn = slim.assign_from_checkpoint_fn(MODEL_PATH, variables_to_restore)
  • Set up the optimizer and specify which variables to update
# Initialize the variables that were not restored from the checkpoint.
logits_variables1 = slim.get_variables('MobilenetV1/Logits/Conv2d_1c_1x1')
logits_init1 = tf.variables_initializer(logits_variables1)
logits_init_list = [logits_init1]

# Assumed loss, not shown in the original excerpt: softmax cross-entropy
# over one-hot labels (weight-decay regularization losses not included).
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

global_step = slim.create_global_step()
with tf.name_scope("train_op"):
    # fc8_optimizer = tf.train.GradientDescentOptimizer(BASE_LR1)
    # Do not define too many optimizers, or memory usage blows up.
    learning_rate = tf.train.exponential_decay(
        PARAMS.params['model']['baseLR'],
        global_step,
        train_num / BATCH_SIZE,  # train_num: number of training examples, defined elsewhere
        PARAMS.params['model']['decayLR'])
    # To freeze the earlier layers, pass var_list to minimize(); it selects the
    # variables to update, and omitting it updates every variable in the graph
    # (see the sketch after this block).
    # Train all layers:
    full_train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
with tf.name_scope("accuracy"):
    prediction = tf.argmax(logits, 1)
    correct_prediction = tf.equal(prediction, tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar("accuracy", accuracy)
init_op = [tf.global_variables_initializer(), tf.local_variables_initializer()]
summary_merge = tf.summary.merge_all()

# Finalize the graph, making it read-only.
tf.get_default_graph().finalize()
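As the comment in the block above notes, passing var_list restricts which variables the optimizer updates. A sketch of the frozen-backbone variant, assuming you want to train only the replaced classification layer (it reuses logits_variables1 from above; run only one of the two train ops per step):

with tf.name_scope("frozen_train_op"):
    # Only the new logits layer is updated; all backbone weights stay frozen.
    logits_train_op = tf.train.AdamOptimizer(learning_rate).minimize(
        loss, global_step=global_step, var_list=logits_variables1)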
  • Train
with tf.Session() as sess:
    # Initialize variables first, then restore the pretrained weights.
    # Running tf.global_variables_initializer() after init_fn would overwrite
    # the restored weights with random values, so the order matters.
    sess.run(init_op)  # initialize custom variables such as global_step
    sess.run(logits_init_list)  # initialize the new logits layer
    init_fn(sess)  # load the pretrained weights
    # tensorboard_writer = tf.summary.FileWriter(TENSORBOARD_PATH, sess.graph)
    for epoch in range(NUM_EPOCHS):
        print('Starting epoch %d / %d' % (epoch + 1, NUM_EPOCHS))
        sess.run(train_iterator.initializer)
        while True:
            try:
                train_batch_images, train_batch_labels \
                    = sess.run([train_images, train_labels])
                _, train_loss, train_acc = sess.run([full_train_op, loss, accuracy],
                                                    feed_dict={is_training: True,
                                                               images: train_batch_images,
                                                               labels: train_batch_labels})
                val_batch_images, val_batch_labels = \
                    sess.run([val_images, val_labels])
                val_loss, val_acc = sess.run([loss, accuracy],
                                             feed_dict={is_training: False,
                                                        images: val_batch_images,
                                                        labels: val_batch_labels})
                step = sess.run(global_step)
                print("global_step:{0}".format(step))
                print("epoch:{0}, train loss:{1}, train acc:{2}".format(epoch, train_loss, train_acc))
                print("epoch:{0}, val loss:{1}, val acc:{2}".format(epoch, val_loss, val_acc))
            except tf.errors.OutOfRangeError:
                break
    # tensorboard_writer.close()
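The excerpt never writes the fine-tuned weights back to disk. A minimal saving sketch with tf.train.Saver (the output path is illustrative); note that the Saver must be constructed before tf.get_default_graph().finalize() is called, because saving adds ops to the graph:

# Before finalize():
saver = tf.train.Saver()

# Inside the session, after the training loop:
save_path = saver.save(sess, './finetuned_mobilenet_v1_0.25.ckpt')
print("Model saved to {0}".format(save_path))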
