TensorFlow2.0 中 GradientTape()函数详解

TensorFlow2.0 中 GradientTape()函数详解

一、函数

tf.GradientTape(
    persistent=False, watch_accessed_variables=True
)

二、作用

tensorflow 提供tf.GradientTape api来实现自动求导功能。只要在tf.GradientTape()上下文中执行的操作,都会被记录与“tape”中,然后tensorflow使用反向自动微分来计算相关操作的梯度。

可训练变量(由tf.Variable或创建tf.compat.v1.get_variabletrainable=True在两种情况下均为默认值)将被自动监视。通过watch在此上下文管理器上调用方法,可以手动监视张量。

三、参数

  • persistent:布尔值,用于控制是否创建持久渐变磁带。默认情况下为False,这意味着最多可以在此对象上对gradient()方法进行一次调用。
  • watch_accessed_variables:布尔值,控制watch在磁带处于活动状态时磁带是否将自动访问任何(可训练的)变量。默认值为True,可以从磁带中读取可训练的磁带得出的任何结果中请求梯度Variable。如果为False,则用户必须明确要求他们要从中请求渐变的watch任何Variable

四、举个栗子

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x) # 计算结果为 6.0

with的用法可参见另外一篇博客详解。

watch的作用是确保tensor类型的数据能被梯度带检测到。

再举个栗子,可以嵌套GradientTapes以计算高阶导数:

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  with tf.GradientTape() as gg:
    gg.watch(x)
    y = x * x
  dy_dx = gg.gradient(y, x)     # 计算结果为 6.0
d2y_dx2 = g.gradient(dy_dx, x)  # 计算结果为 2.0

高能重点来了!!!

默认情况下,只要调用GradientTape.gradient()方法,就会释放GradientTape拥有的资源。**也就是说只能用一次!!!**要在同一计算上计算多个梯度,请创建一个持久梯度带。当梯度带回收,释放资源时,这允许多次调用gradient()方法。例如:

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y
dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
dy_dx = g.gradient(y, x)  # 6.0
del g  # Drop the reference to the tape

另外,默认情况下,GradientTape将自动监视在上下文中访问的所有可训练变量。如果要对监视哪些变量进行精细控制,可以通过传递watch_accessed_variables=False给梯度带构造函数来禁用自动跟踪 :

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(variable_a)
  y = variable_a ** 2  # 梯度对于`variable_a`将是可用的.
  z = variable_b ** 3  # 梯度对于`variable_b`将是不可用的,因为`variable_b`没有被监视,即watch()。

你可能感兴趣的:(TensorFlow归纳)