[tf] Scaffolds

image.png

scaffolds可以被传进tf.estimator和tf.MonitoredTrainingSession里面,用于训练和其他,在scaffolds里面可以指定各种各样的操作符,但是现在我们集中于init_fn这个操作符,因为其他操作符都是用于分布式设置的,我们一般是用不到的,init_fn在graph被建立之后以及sess.run被调用前时首次调用。他只会被调用一次,所以是一个很好的可以能自定义变量初始化的方式,一个小例子说明一下。

  • 原来是这么写的:
    # look for checkpoint
    model_path = tf.train.latest_checkpoint(path)
    initializer_fn = None

    if model_path:
        # only restore variables in the scope_name scope
        variables_to_restore = slim.get_variables_to_restore(include=[scope_name])
        # Create the saver which will be used to restore the variables.
        initializer_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)
    else:
        print("could not find the fine tune ckpt at {}".format(path))
        exit()

    def InitFn(scaffold,sess):
        initializer_fn(sess)
    return InitFn
  • 现在是这么写的:
# setup fine tune scaffold
scaffold = tf.train.Scaffold(init_op=None,
     init_fn=tools.fine_tune.init_weights(params["fine_tune_ckpt"]))

# create estimator training spec
return tf.estimator.EstimatorSpec(tf.estimator.ModeKeys.TRAIN,
                                loss=loss,
                                train_op=train_op,scaffold=scaffold)

你可能感兴趣的:([tf] Scaffolds)