目录
本文主要粗浅的讲解tf.function的相关内容,主要分为两块内容,一是tf.function的作用(为什么要有tf.function)。二是在使用tf.function时需要注意的点(tf.function的一些特性)
tf.function的作用
通过对比 tf1.x 与 tf2.x eager 与 tf2.x tf.functon进行说明tf.function的作用
tf.1.x
首先从tf1.x的代码风格说起,在tf1.x中,我们需要自行创建一个graph,再把其加载进tf.Session中,最后使用tf.Session.run方法开始计算(下面的例子很清晰的说明了tf1.x中Graph的创建与计算流程)
with g.as_default():
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
y = tf.matmul(a, x) + b
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print(sess.run(y))
tf.2.x eager
但是在tf2.x中,默认的Eager execution代码风格发生了很大的变化,具体表现在:1)移除了关于graph的定义 ;2)移除了session执行;3)移除了变量初始化;4)移除了通过scope实现variable sharing;5)移除tf.control_dependencies以执行没有依赖关系连接的顺序操作。如下方代码所示:
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
y = tf.matmul(a, x) + b
print(y.numpy())
tf2.x tf.functon
tf.function本质上就是一个函数修饰器,它能够帮助将用户定义的python风格的函数代码转化成高效的tensorflow计算图(可以理解为:之前tf1.x中graph需要自己定义,现在tf.function能够帮助一起定义)。转换的这个过程称为AutoGraph。如下方代码所示:
@tf.function
def f():
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
y = tf.matmul(a, x)
return y
f()
我们可以先来简单的了解下tf.function修饰器具体做了哪些工作?
该函数被执行和跟踪(tracing). Eager模型被关闭禁用,所有的tf.api方法都被当做tf.Operation来构建Graph,并产生tf.Tensor output
AutoGraph被用于检测代码中是否存在能够被转换为graph的等价操作(while -> tf.while, for -> tf.for, if -> tf.cond, assert -> tf.assert, …)
经过上述两步后,为了保证graph中语句的执行顺序,tf.control_dependencies被自动加入到代码中,保证第i行执行完后执行第i+1行
创建tf.Graph对象,并根据函数名称和输入参数,创建与Graph相关联的唯一ID,并被加载进一个map中:map[id] = graph
对于任何一个函数的调用,都会重复调用cache中的graph
小结
通过上述的 tf1.x与tf2.x eager与tf2.x tf.functon的对比,我们已经大概知道了tf.function的具体作用(在保证代码易读以及易写的前提下提升执行效率)。接下来将简单介绍在使用tf.function时需要注意的一些点(帮助我们高效的使用tf.function,提升执行效率)
使用tf.function时需要注意的点
tf.variable只会被创建一次
首先来看下面一段代码:
@tf.function
def f():
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
b = tf.Variable(12.)
y = tf.matmul(a, x) + b
return y
f()
为啥会报错呢?是因为tf.function可能会对一段python代码进行多次执行来进行graph的构建,在多次的执行过程中,Variable被创建了多次,而tensorflow 2.x文档中明确指出State (like tf.Variable objects) are only created the first time the function f is called.。所以就报错了。
因此我们在构建被tf.function修饰的函数时,一定要记得保证每一个tf.Variable变量只被创建一次,否则就有可能报错。那么关于上述报错代码,正确的写法应该是怎样呢?(如下)
class F():
def init(self):
self._b = None
@tf.function
def call(self):
a = tf.constant([[10, 10], [11., 1.]])
x = tf.constant([[1., 0.], [0., 1.]])
if self._b is None:
self._b = tf.Variable(12.)
y = tf.matmul(a, x) + self._b
print("PRINT: ", y)
tf.print("TF-PRINT: ", y)
return y
f = F()
f()
参数类型与参数值决定是否创建新的graph
tf.function的最终会创建一个具有唯一标识ID的graph,如下代码所示:
@tf.function
def f(x):
return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
通常被修饰的函数的参数类型与参数值会决定是否创建新的graph。我们可以分为传入参数类型是tf.Tensor与原始python类型两种情况讨论。
当传入参数类型是tf.Tensor时
当传入不同类型、相同shape、相同参数值的tf.Tensor值时
@tf.function
def f(x):
return x + 1
vector = tf.constant([1], dtype=tf.int32)
matrix = tf.constant([1.0], dtype=tf.float32)
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
当传入相同类型、不同shape、相同参数值的tf.Tensor值时
@tf.function
def f(x):
return x + 1
vector = tf.constant([1.0], dtype=tf.float32)
matrix = tf.constant([[1.0]], dtype=tf.float32)
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
当传入相同类型、相同shape、不同参数值的tf.Tensor值时
@tf.function
def f(x):
return x + 1
vector = tf.constant([1.0], dtype=tf.float32)
matrix = tf.constant([3.0], dtype=tf.float32)
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
因此当传入参数类型是tf.Tensor时,我们可以得出结论:类型与shape决定了新建还是复用graph
传入不同类型、相同shape、相同参数值时
@tf.function
def f(x):
return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(1.0)
f1 is f2
传入相同类型、不同shape、相同参数值时
@tf.function
def f(x):
return tf.abs(x)
f1 = f.get_concrete_function([1])
f2 = f.get_concrete_function(1)
f1 is f2
传入相同类型、相同shape、不同参数值时
@tf.function
def f(x):
return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2)
f1 is f2
因此当传入参数类型是原始python类型时,我们可以得出结论,值与shape决定了新建还是复用graph
建议:在使用tf.function时,要本着尽可能复用graph的原则,因为新建graph耗时且耗资源
在被tf.function修饰的函数体中,代码可以大概分为两类,其一为python风格的代码(比如print()),其二为tensorflow风格的代码(比如tf.print)。对于python风格的语句,只有在新建graph的过程才会被执行一次(复用graph时则不会被执行)。如下代码所示:
@tf.function
def f(a, b):
print(‘this runs at trace time; a is’, a, ‘and b is’, b)
return b
f(1, tf.constant(1))
f(1, tf.constant(2))
f(2, tf.constant(1))
f(2, tf.constant(2))
总结
本文主要从tf.function的作用以及tf.function的一些特性两个方面进行简单讲解。事实上关于使用tf.function进行Autograph的创建是一个比较的复杂的过程,还包括tf.function具体是如何修饰函数的(过程细节)、以及如何高效的利用tf.function等多个需要探讨的点