GradientTape
是eager
模式下计算梯度用的,而eager
模式(eager模式的具体介绍请参考文末链接)是TensorFlow 2.0
的默认模式。 通过GradientTape
可以对损失的计算过程、计算方式进行深度定制,即所谓的Custom training
, 而不仅仅是通过model.train
这样过于高级(傻白甜)的API的方式进行训练。这在很多场合下是非常有用的。
tf.GradientTape
定义在tensorflow/python/eager/backprop.py
文件中。下面通过几个Demo
来逐步深入地学习GradientTape
的用法。
计算 z = x 2 z=x^2 z=x2关于x的梯度, x 0 x_{0} x0 = 3.0, 结果为:6
# -*- coding: utf-8 -*-
import tensorflow as tf
from functools import partial
x = tf.constant(3.0, dtype=tf.float32)
y = tf.constant(2.0, dtype=tf.float32)
with tf.GradientTape() as g:
g.watch(x)
z = x * x
dz_x = g.gradient(z, x)
tf.print(dz_x)
z = x y 2 z=xy^2 z=xy2,计算 z/x, z 2 z^2 z2/xy 的导数,结果分别为:4, 4,
with tf.GradientTape() as g:
g.watch(y)
with tf.GradientTape() as gg:
gg.watch(x)
z = x*y*y
dz_dx = gg.gradient(z, x)
tf.print(dz_dx)
dzx_y = g.gradient(dz_dx, y)
tf.print(dzx_y)
z = x 2 z=x^2 z=x2,计算 z/x, z 2 z^2 z2/xx 的导数,结果分别为:6, 2,
with tf.GradientTape() as g:
g.watch(x)
with tf.GradientTape() as gg:
gg.watch(x)
y = x * x
dy_dx = gg.gradient(y, x)
print(dy_dx)
d2y_dx2 = g.gradient(dy_dx, x)
print(d2y_dx2)
由于tf.GradientTape()
自带的参数persistent
默认=False
,因此只能调用一次tape.gradient
来进行求导, 怎么办呢? 两种方法:
(1)定义两个tape
,然后每个tape
分别对其中一个目标求导,如下:
with tf.GradientTape() as g, tf.GradientTape() as f:
g.watch(x)
f.watch(y)
z = x * x + y
m = y + y*y + x
dz_x = g.gradient(z, x)
dm_y = f.gradient(m, y)
print(dz_x, dm_y)
(1)定义一个tape,并设置persistent=True
,(否则会出现类似RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes
的错误)如下:
with tf.GradientTape(persistent=True) as gf:
gf.watch([x, y])
z = x * x + y
m = y + y*y + x
dz_x = gf.gradient(z, x)
dm_y = gf.gradient(m, y)
print(dz_x, dm_y)
del gf
记得最后要将tape删除: del gf
。
本人的实际代码环境比较复杂,仅通过一下的简化代码说明问题。
注:model和model_d都为module.
model = build_model()
model_d = build_model2()
with tf.GradientTape() as ta:
embed= model(input)
pred = model_d(embed)
gp = gradient_penality(model_d, pred, tgt) # call function
grad_2 = ta.gradient(gp, model_d.trainable_parameters)
def gradient_penality(model_d, pred, tgt):
# a fucntiuon execute gradient penality based on inputs.
with tf.GradientTape(0 as tb:
tb.watch(intermediate)
intermediate = f(pred, tgt) # f: function
output = model_d(intermediate)
grad = tb.gradient(output, intermediate)
return f1(grad) # f1: function
如上所示,本人遇到的问题是:grad_2
输出结果总是为:None
, 从而导致网络第二点此处更新时参数全部为NaN
。该问题是本人在尝试将Improved Training of Wasserstein GANs这篇文章的代码由tf1转为tf2的过程中出现的。
分析: 假设model_d
本身是变量的话那就简单了:按照Demo 2
:GradientTape
的嵌套处理即可。然而此处,model_d
是模型,其本身带有可训练的参数:model_d.trainable_parameters
,
解决方案:
将 gp = gradient_penality(model_d
, pred, tgt) 替换为:
gp = gradient_penality(partial(model_d, training=True)
, pred, tgt)即可。
关于functional.partial()的用法:
functools.partial
是偏函数,它的本质就是基于一个函数创建一个新的可调用对象, 把原函数的某些参数固定。 使用这个函数可以把接受一个或多个参数的函数改编成需要回调的API, 这样参数更少。如:functools.partial(api_export, p1)的作用是把函数api_export的第一个参数固定为p1,functools.partial(api_export, p=p1)的作用是把函数api_export的参数p固定为p1,api_export是实现了__call__()函数的类.
tf_export=unctools.partial(api_export, api_name=TENSORFLOW_API_NAME)
的写法等效于:
funcC = api_export(api_name=TENSORFLOW_API_NAME)
// 会调用_init_构造函数
tf_export = funcC
, //函数名称tf_export, 调用时执行_call_
1.WGAN with gradient penality 的相关实现https://github.com/igul222/improved_wgan_training
2.functional.partial()的相关资料https://blog.csdn.net/menghaocheng/article/details/83479754