Tensorflow:tf.gradient()用法以及参数stop_gradient理解

tf.gradient()

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None
)

ys : 类型是张量或者张量列表,类似于目标函数,需要被微分的函数
xs:类型是张量或者张量列表,需要求微分的对象。(上述即为:dys/dxs)
stop_gradients: 可选参数,类型是张量或者张量列表,不需要通过微分的对象(比较抽象,看完下面的例子)

用一个例子来帮助理解

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b])
with tf.Session() as sess:
    print(sess.run(g))
结果:[3.0, 1.0]

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a])
with tf.Session() as sess:
    print(sess.run(g))
结果:[3.0, 1.0]

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[b])
with tf.Session() as sess:
    print(sess.run(g))
结果:[1.0, 1.0]  

可以看出,第一个参数ys是准备被微分的函数,第二个参数即xs填的是反向传播是需要求导的参数,第三个参数即stop_gradient,在反向传播时,如果填了参数b,那么a + b中a,b都是独立的,否则a + b= 3a(因为在本例中b = 2a)


如果觉得我有地方讲的不好的或者有错误的欢迎给我留言,如果对您有帮助,帮我点个赞哦~,感谢大家阅读

你可能感兴趣的:(Tensorflow)