【Tensorflow】Tensorflow 自定义梯度

目录

前言   

自定义梯度

说明

gradient_override_map的使用

多输入与多输出op

利用stop_gradient

参考


【fishing-pan:https://blog.csdn.net/u013921430 转载请注明出处】

前言   

       在Tensorflow中大部分的op都提供了梯度计算方式,可以直接使用,但是有少部分op并未提供。此时,就需要使用者自己定义梯度了。我查阅了一些资料,发现主要有两种梯度的计算方式。

自定义梯度

       这种方法主要是参考Tensorflow对梯度的定义;使用tf.RegisterGradient()函数结合tf.get_default_graph().gradient_override_map()函数,前者注册梯度,后者赋予梯度;使用方法如下;例如,使用自定义的梯度实现 tf.multiply() 函数的梯度。

@tf.RegisterGradient('mulGrad')
def mulGrad(op, grad):
    grad = grad * op.inputs[1]

    return [grad, None]

W = tf.get_variable('w', shape=[1], dtype=tf.float32, initializer=tf.truncated_normal_initializer(0.0, 0.02))
I = tf.constant(10.0, dtype=tf.float32)

with tf.get_default_graph().gradient_override_map({'Mul': 'mulGrad'}):
    Y = tf.multiply(W, I, name='Multiply')

说明

       使用这种方法,有两个地方需要说明;

gradient_override_map的使用

        gradient_override_map()的输入中,冒号前是op的类型而不是name(为了区分,我特地给上面代码中的乘法取了跟type不一样的name)。但是很多时候我们无法的得知op的type,此时我们需要首先通过get_operation_by_name()来获取type。例如上面例子中的获取乘法的type的方法是;

graph = tf.get_default_graph()
print(graph.get_operation_by_name('Multiply'))

        最终的输出的前几行是如下,其中op就表示操作的type。

name: "Multiply"
op: "Mul"
input: "w/read"
input: "Const"

多输入与多输出op

       当op有多个输入输出时,梯度函数的定义是不一样的。由于梯度反向传播与前向传播想法,当op有多个输入时,就需要返回多个梯度,上面的例子中就是op有两个输入的情况,梯度函数应该返回两个输出。同样的道理,当op有多个输出时,在梯度回传时,会传入多个梯度,此时函数的形参需要多个梯度,例如下面的例子就是两个输入,三个输出的op的梯度重定义。

@tf.RegisterGradient('funcGrad')
def funcGrad(op, grad1,grad2,grad3):
    grad = grad1*grad2 * op.inputs[1]

    return grad,None

利用stop_gradient

       当op的梯度回传比较简单时,可以使用tf.stop_gradient()实现梯度回传,例如;

output = input + tf.stop_gradient(func(input) –input)
#正向传递,式子相当于
#output = func(input) 
#反向传播,式子相当于
#output = input

        上面的例子中,当前向传播时,传递op的输出,反向传播时跳过op梯度回传直接传递给input;当梯度需要经过复杂的变换时,可以通式表示为;

output = gard_func(input) + tf.stop_gradient(func(input) –gard_func(input) )

参考

  1. https://blog.csdn.net/LoseInVain/article/details/83108001
  2. https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/ops/array_grad.py

你可能感兴趣的:(TensorFlow)