pointnet学习(五)train函数,第五、六句

第五句
batch = tf.Variable(0)
声明一个tensor常量,值为0,主要是供第六句get_bn_decay使用
作者给的解释是,这个batch用来设置glob_step。
第六句
bn_decay = get_bn_decay(batch),这一句用来设置train的过程中学习率的衰减系数的。

具体实现如下:
 

def get_bn_decay(batch):
    bn_momentum = tf.train.exponential_decay(
                      BN_INIT_DECAY,
                      batch*BATCH_SIZE,
                      BN_DECAY_DECAY_STEP,
                      BN_DECAY_DECAY_RATE,
                      staircase=True)
    bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
    return bn_decay

通过调用 tf.train.exponential_decay,以及tf.minimum来构建出bn_decay这个tensor。

tf.train.exponential_decay

tf.compat.v1.train.exponential_decay(
    learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None
)

根据官方的解释,在train的过程中,只定一个learningrate显然不够智能,也不够灵活,在train的过程中会出现各种问题,所以建议在train过程中,让learningrate按照某种规律进行衰减, exponential_decay函数则提供了一个指数衰减的函数用来更新learningrate。计算公式如下

decayed_learning_rate = learning_rate *
                        decay_rate ^ (global_step / decay_steps)

 从code可以看出,learning_rate为我们设置的初始learning_rate,我们的代码里面是BN_INIT_DECAY=0.5。global_step可以通过设置一个常量tensor来指定,我们这里利用了batch*BATCH_SIZE来指定,BATCH_SIZE默认为64,batch为0.(这里比较疑惑的是,batch这个tensor常量为0,这每次global_step都是0,岂不是不起作用了,每次learning_rate跟之前都是一样的)decay_steps代码里面设置的是DECAY_STEP=200000, decay_rate代码里面设置的是DECAY_RATE=0.7,是我们的learning_rate的衰减参数;由此可见global_step是会变化的,每次trainprocess都会变化,第一步是0,后续进行更新。最后一个参数是staircase,这里staircase设置为true那么我的global_step / decay_steps则取整数,这样我们的learning_rate衰减的指数值则为楼梯状的。

官方例子如下:

...
global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.1
learning_rate = tf.compat.v1.train.exponential_decay(starter_learning_rate,
global_step,
                                           100000, 0.96, staircase=True)
# Passing global_step to minimize() will increment it at each step.
learning_step = (
    tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
    .minimize(...my loss..., global_step=global_step)
)
tf.minimum官方例子也有个minimize

tf.minimum的功能是取较小的值,而代码里面用了bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum),其中BN_DECAY_CLIP=0.99,也就是说,作者这里的learningrate是逐渐增加的。从0.5开始,最大到0.99.

 

你可能感兴趣的:(pointnet,tensorflow)