挑选出tensor中等于0的索引_[TensorFlow2.0文档翻译] 使用tf.function提升性能

挑选出tensor中等于0的索引_[TensorFlow2.0文档翻译] 使用tf.function提升性能_第1张图片

TensorFlow2.0 默认动态图机制(eager execution)。用户接口直观且灵活(运行一次性的操作很简单且快) 但是可能会牺牲性能和部署的灵活性。

为了了解部署模型的性能, 使用 tf.function 从程序中分离出拓扑图,这得益于AutoGraph和pyton中一些高质量的代码。 但是依然存在一些缺陷需要谨慎对待:

主要的打包和建议是:

  • 不要依赖python中的副作用对象例如可变对象和追加list
  • tf.function和TensorFlow操作的兼容性 要比和NumPy操作以及原生的Python操作要好
  • 如果存在疑虑的时候, 无脑使用for x in y 这种写法。

定义一个辅助方法来处理异常

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

基础

定义一个tf.function就像TensorFlow核心操作一样:可以直接执行; 可以在图中使用;可以计算梯度等等

@tf.function
def add(a, b):
    return a + b

print(add(tf.ones([2, 2]), tf.ones([2, 2])))

# 输出
tf.Tensor(
[[2. 2.]
 [2. 2.]], shape=(2, 2), dtype=float32)

也可以在其他funciton中使用

@tf.function
def dense_layer(x, w, b):
    return add(tf.matmul(x, w), b)

print(dense_layer(tf.ones([2, 3]), tf.ones([3, 2]), tf.ones([2, 2])))

# 输出
tf.Tensor(
[[4. 4.]
 [4. 4.]], shape=(2, 2), dtype=float32)

追踪和多态

python的动态类型特性意味着你可以不同的参数类型调用functions, 而且Python可以做同样的事情(类似于C++、Java中的重载)。另一方面TensorFlow的计算图要求静态类型和维度(shape dimensions),tf.function在需要的时候追踪funcitons来生成正确的图,用这种方式平衡前述两者的差异. 许多tf.function细节上都基于这种实现(多态+追踪)。

可以调用一个函数使用不同的参数类型(类似于重载?)

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

# 输出
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

为了控制这种追踪行为,使用以下技术:

  • 创建一个新的tf.function, 和tf.function对象分开保证步共享相同的追踪
  • 使用get_concrete_function方法得到特定的追踪
  • 当调用tf.function只追踪一次的时候使用input_signature
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
# 输出
Obtaining concrete trace
Tracing with Tensor("a:0", dtype=string)
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
Using a concrete trace with incompatible types will throw an error
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 8, in 
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_87 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_87]
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))
# 输出
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 9, in 
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

什么时候使用追踪?

使用追踪产生的构造函数会放在多态tf.function的缓存中,缓存中的值是fuction函数中参数args和kwargs组成的key的组合。这些key(缓存的键)在tf.Tensor中是一个张量(包含维度和type), 在原生Python中是他们的值。对于所有的Python类型缓存中的key取决于对象的id(),所以这个方法可以独立的追溯到每一个实例。 将来TensorFlow可能会为Python对象添加更多复杂的缓存以便更安全的转换张量。查看 Concrete functions

Python或是张量对象?

通常Python参数用于控制超参数和计算图结构。例如num_layers=10ortraining=Trueornonlinearity='relu' , 如果Python参数改变了,就必须重新生成计算图。

但是,也不是所有的Python参数都来控制计算图结构。这时Python参数的改变会引起不必要的追踪。举个例子,在训练循环中,自动图会自动展开,尽管有多个追踪,但是生成的图是一样的,这样就有些低效

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_one_step()

train(num_steps=10)
train(num_steps=20)

# 输出
Tracing with num_steps = 10
Tracing with num_steps = 20

解决方案是把参数转换成张量,这样也不会影响图的的生成

train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
# 输出
Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32)

tf.function的副作用

一般来说,Python的副作用(打印或改变对象)只发生在追踪过程中。那如何稳定可靠的触发tf.function的副作用?

约定俗成的规则是只在debug中使用Python的副作用,但是TensorFlow的操作像tf.Variable.assign,tf.print, andtf.summar可以很好的保证代码可追踪且TensorFlow运行时可调用。使用函数式风格可以输出最好的结果

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)  
# 输出
Traced with 1 # 调用两次f(1) 只执行了一次python的打印, 执行了两次tf.print
Executed with 1
Executed with 1
Traced with 2
Executed with 2

如果你想在每次batch结束的时候调用tf.function、tf.pyfunction的时候调用Python代码。 这里有一个缺点。tf.py_function不够轻便且不是即时有效,而且在分布式环境(多卡GPU、TPU)里也容易出问题,同时tf.py_function必须链接到不同的图,它把所有的输入输出都转换为张量。

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1

# 输出
Python side effect
Python side effect
Python side effect # 调用三次 执行了三次python的print

注意Python的状态

许多Python的特性,比方说说生成器和迭代器.依赖于Python运行时追踪的状态。一般的,即使在Eager模式下正常工作,因为追踪的存在依然会在tf.function内部发生很多意想不到的事情。

例如:

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

# 输出
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0 # 这里迭代器没有生效

如果迭代器的生成和消费都在tf.function中执行, 那就可以正常工作。 但是,整个迭代器都被追踪后, 就会被引向一个巨大的图,这可能正是你想要的,但是如果在训练占用大量内存的pyhton list时,就会生成一个非常大的图,tf.function也不大可能会加速。

如果想要便利所有的Python迭代器里的数据, 比较可靠的方法是把数据打包到tf.data.Dataset对象里然后无脑使用for x in y 遍历数据, 当y是一个tensor或tf.data.Dataset的时候AutoGraph对这种操作做了支持

def measure_graph_size(f, *args):
    g = f.get_concrete_function(*args).graph
    print("{}({}) contains {} nodes in its graph".format(
        f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
    loss = tf.constant(0)
    for x, y in dataset:
        loss += tf.abs(y - x)
    return loss

small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10

measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))

# 输出
train([(1, 1), (1, 1)]) contains 8 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(, ), types: (tf.int32, tf.int32)>) contains 9 nodes in its graph
train(, ), types: (tf.int32, tf.int32)>) contains 9 nodes in its graph

当打包Python/Numpy数据到Dataset的时候,注意tf.data.Dataset.from_generator和tf.data.Dataset.from_tensors 通过tf.py_function获取时会保持之前的数据,这会是为了提高性能, 而剩下的会打包成一个数据的备份 在计算图中作为一个tf.constant()节点 这样是为了节省内容。

使用TFRecordDataset/CsvDataset/...可以更高效的消费数据,TensorFlow可以异步加载和获取数据,无需转换成原生Python类型。

自动控制依赖

使用数据流图进行函数式编程有一个很吸引人的特性,函数在运行时可以提供更多关于执行代码的信息。

举个例子,编写多处需要读写相同变量的代码时, 数据流图可能不会对原始的操作进行编码。在tf.function中 执行顺序按照Pyton代码的声明顺序解决歧义。 这样tf.function里的状态会被复制一份到Eager模式下。

这意味着不用手动添加依赖控制,tf.function足够智能去添加最小最需要最高效的依赖来保证程序正确执行

# Automatic control dependencies

a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0

# 输出

参数

我们可以方便的使用tf.function按照预想的执行顺序来创建参数。但是有一个很重要的警告——同样的参数代码在eager模式和graph模式中执行不同。

特别是在创建一个参数后并且每次调用时。由于追踪机制,在每次调用时tf.function可以重复使用相同的参数, 但是eager模式会重新创建一个参数,为了防止这种错误, tf.function会跑出一个错误来发现这种异常

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)

# 输出
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 8, in 
    f(1.0)
ValueError: in converted code:

    :3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:65 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py:502 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

下面是没有歧义的代码

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0

# 输出
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

只要保证只初始化一次的代码, 依然可以在tf.function中编写

class C:
  pass

obj = C()
obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))  # 2.0
print(g(2.0))  # 4.0

# 输出
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

参数初始化依赖于函数参数和其他参数, 可以使用相同的方法生成控制依赖并查明初始化顺序。

state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))

# 输出
tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)

使用 AutoGraph

tf.function中完全继成了autofraph库,他可以重写条件和循环,这取决于张量在图中的动态运行。

tf.cond和tf.while_loop同样可以和tf.function一起运行,但是使用控制流更容易编写和易读。

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))

# 输出

[0.626656175 0.382405877 0.784070969 0.693427086 0.150599718]
[0.555745482 0.364795 0.655037165 0.600179076 0.149471417]
[0.50481391 0.349430591 0.575051188 0.537176967 0.148368135]
[0.46589461 0.335870445 0.519059181 0.490848064 0.147288963]
[0.434876293 0.323785722 0.476973534 0.454889268 0.146233037]
[0.409388453 0.312925965 0.443816543 0.425909698 0.145199522]
[0.387953311 0.303096592 0.416802973 0.401897341 0.144187644]
[0.369594455 0.294143856 0.394233704 0.381571233 0.143196672]
[0.35363692 0.285944343 0.375004321 0.36407122 0.142225876]
[0.339597 0.278397709 0.358361155 0.348795027 0.141274586]
[0.327117532 0.271421462 0.34376967 0.335306466 0.140342161]
[0.31592837 0.264947 0.330838621 0.323280782 0.139427975]
[0.305820704 0.258916765 0.319274098 0.312470376 0.138531446]
[0.296630263 0.25328207 0.308850408 0.302682817 0.137652025]
[0.2882258 0.248001367 0.299390912 0.293765813 0.136789158]
[0.280501 0.243039 0.290755093 0.285597175 0.13594234]
[0.273368716 0.238364145 0.282829642 0.278077424 0.135111064]
[0.266756624 0.233950019 0.275521934 0.27112475 0.134294882]
[0.260604292 0.229773208 0.268755466 0.264671087 0.133493334]
[0.254860669 0.225813136 0.262466401 0.258659333 0.132705986]
[0.249482289 0.22205165 0.256600976 0.253041118 0.131932423]
[0.244431943 0.21847266 0.251113564 0.247775227 0.13117224]
[0.239677504 0.215061843 0.245965138 0.242826208 0.130425066]
[0.235191122 0.211806417 0.241122097 0.238163441 0.129690528]
[0.230948448 0.20869489 0.236555338 0.233760297 0.128968269]
[0.22692816 0.205716968 0.2322395 0.2295935 0.12825796]
[0.223111436 0.202863321 0.228152364 0.225642592 0.127559274]

如果你对他怎么实现的感兴趣可以查看实现代码

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))

# 输出
def tf__f(x):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:

    def get_state():
      return ()

    def set_state(_):
      pass

    def loop_body(x):
      ag__.converted_call(tf.print, (x,), None, fscope)
      x = ag__.converted_call(tf.tanh, (x,), None, fscope)
      return x,

    def loop_test(x):
      return ag__.converted_call(tf.reduce_sum, (x,), None, fscope) > 1
    x, = ag__.while_stmt(loop_test, loop_body, get_state, set_state, (x,), ('x',), ())
    do_return = True
    retval_ = fscope.mark_return_value(x)
  do_return,
  return ag__.retval(retval_)

AutoGraph: 条件

AutoGraph可以使用tf.cond替换if声明。 这种替换发生在当条件判断是一个张量时,这种条件会在追踪时运行。

下面这段代码可以检查计算图是否使用了tf.cond:

def test_tf_cond(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'cond' for node in g.as_graph_def().node):
    print("{}({}) uses tf.cond.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) executes normally.".format(
        f.__name__, ', '.join(map(str, args))))

  print("  result: ",f(*args).numpy())

当if条件判断是Tensor时会发生替换,另一方面条件判断发生在追踪时。

传递一个Python的True给条件判断

@tf.function
def dropout(x, training=True):
  if training:
    x = tf.nn.dropout(x, rate=0.5)
  return x

test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)

# 输出
dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), True) executes normally.
  result:  [0. 2. 0. 2. 2. 2. 0. 0. 0. 0.]

But passing a tensor replaces the python if with a tf.cond:

test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))
# 输出
dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), tf.Tensor(True, shape=(), dtype=bool)) uses tf.cond.
  result:  [0. 0. 0. 2. 0. 0. 2. 0. 2. 2.]

tf.cond有一些细节

在条件判断和执行分支的时候都依赖于追溯机制,也依赖于条件。如果使用Pyhton原生代码会输出不一样的结果

@tf.function
def f(x):
  if x > 0:
    x = x + 1.
    print("Tracing `then` branch")
  else:
    x = x - 1.
    print("Tracing `else` branch")
  return x


f(-1.0).numpy()
# 输出
Tracing `else` branch

-2.0


f(1.0).numpy()
# 输出

Tracing `then` branch

2.0

f(tf.constant(1.0)).numpy()
# 输出

Tracing `then` branch
Tracing `else` branch # 当参数使用张量时 两个分支的的python代码都被执行了

2.0

如果一个分支里创建了一个张量,那另一个分支里也必须创建张量

@tf.function
def f():
  if tf.constant(True):
    x = tf.ones([3, 3])
  return x

# Throws an error because both branches need to define `x`.
with assert_raises(ValueError):
  f()

# 输出

Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 9, in 
    f()
ValueError: in converted code:

    :3 f  *
        if tf.constant(True):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:918 if_stmt
        basic_symbol_names, composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:956 tf_if_stmt
        error_checking_orelse)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507 new_func
        return func(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/control_flow_ops.py:1174 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:90 cond_v2
        op_return_value=pred)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:949 error_checking_orelse
        result[orelse_branch] = orelse()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:987 wrapper
        new_vars = func()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:1013 wrapper
        tuple(s.symbol_name for s in undefined)))

    ValueError: The following symbols must also be initialized in the else branch: ('x',). Alternatively, you may initialize them before the if statement.

如果想确保一段控制流从来没有被autograph转换,可以显式的转换成python类型 这样会爆出错误

@tf.function
def f(x, y):
  if bool(x):
    y = y + 1.
    print("Tracing `then` branch")
  else:
    y = y - 1.
    print("Tracing `else` branch")
  return y

f(True, 0).numpy()
# 输出
Tracing `then` branch

1.0

f(False, 0).numpy()
# 输出
Tracing `else` branch

-1.0

with assert_raises(TypeError):
  f(tf.constant(True), 0.0)

# 输出
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 2, in 
    f(tf.constant(True), 0.0)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in converted code:

    :3 f  *
        if bool(x):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/impl/api.py:416 converted_call
        return py_builtins.overload_of(f)(*args)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:757 __bool__
        self._disallow_bool_casting()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:523 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:510 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

自动图和循环

AutoGraph对于循环有一些简单的规则

  • for:如果一个可迭代的容器是一个张量进行转换
  • while:如果while循环的判断条件是一个张量

如果循环被准还, 将被tf.while_loop动态的展开,或者for x in tf.data.Dataset这种特殊的例子会被转换成tf.data.Dataset.reduce

如果没被转换就会被静态的展开

def test_dynamically_unrolled(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'while' for node in g.as_graph_def().node):
    print("{}({}) uses tf.while_loop.".format(
        f.__name__, ', '.join(map(str, args))))
  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
    print("{}({}) uses tf.data.Dataset.reduce.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) gets unrolled.".format(
        f.__name__, ', '.join(map(str, args))))

For循环

这个例子证明tf.function被静态展开

@tf.function
def for_in_range():
  x = 0
  for i in range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_range)
# 输出
for_in_range() gets unrolled.

@tf.function
def for_in_tfrange():
  x = tf.constant(0, dtype=tf.int32)
  for i in tf.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfrange)
# 输出
for_in_tfrange() uses tf.while_loop.

@tf.function
def for_in_tfdataset():
  x = tf.constant(0, dtype=tf.int64)
  for i in tf.data.Dataset.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfdataset)
# 输出
for_in_tfdataset() uses tf.data.Dataset.reduce.


@tf.function
def while_py_cond():
  x = 5
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_py_cond)
# 输出
while_py_cond() gets unrolled.

@tf.function
def while_tf_cond():
  x = tf.constant(5)
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_tf_cond)
# 输出
while_tf_cond() uses tf.while_loop.

如果有break或者提前return语句依赖于tensor判断,那上一级的判断或循环也应当是个tensor

比较下面的例子

@tf.function
def while_py_true_py_break(x):
  while True:  # py true
    if x == 0: # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_py_true_py_break, 5)

# 输出
while_py_true_py_break(5) gets unrolled.

@tf.function
def buggy_while_py_true_tf_break(x):
  while True:   # py true
    if tf.equal(x, 0): # tf break
      break
    x -= 1
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)

# 输出
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 10, in 
    test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in converted code:

    :3 buggy_while_py_true_tf_break  *
        while True:   # py true
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:755 while_stmt
        return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:870 _py_while_stmt
        while test(*loop_vars):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:757 __bool__
        self._disallow_bool_casting()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:523 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:510 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

@tf.function
def while_tf_true_tf_break(x):
  while tf.constant(True): # tf true
    if x == 0:  # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_tf_true_tf_break, 5)

# 输出
while_tf_true_tf_break(5) uses tf.while_loop.

@tf.function
def buggy_py_for_tf_break():
  x = 0
  for i in range(5):  # py for
    if tf.equal(i, 3): # tf break
      break
    x += i
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_py_for_tf_break)

# 输出 
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 11, in 
    test_dynamically_unrolled(buggy_py_for_tf_break)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in converted code:

    :4 buggy_py_for_tf_break  *
        for i in range(5):  # py for
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:339 for_stmt
        return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:348 _py_for_stmt
        if extra_test is not None and not extra_test(*state):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:757 __bool__
        self._disallow_bool_casting()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:523 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:510 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.


@tf.function
def tf_for_py_break():
  x = 0
  for i in tf.range(5): # tf for
    if i == 3:  # py break
      break
    x += i
  return x

test_dynamically_unrolled(tf_for_py_break)

# 输出
tf_for_py_break() uses tf.while_loop.

如果要累加结果,可能需要用到tf.TensorArray

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))

# 输出

一些问题

tf.cond, tf.while同样有一些小问题

0循环

如果一个循环执行0次

@tf.function
def buggy_loop_var_uninitialized():
  for i in tf.range(3):
    x = i
  return x

with assert_raises(ValueError):
  buggy_loop_var_uninitialized()

# 输出
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 8, in 
    buggy_loop_var_uninitialized()
ValueError: in converted code:

    :3 buggy_loop_var_uninitialized  *
        for i in tf.range(3):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:419 _tf_range_for_stmt
        _disallow_undefs_into_loop(*init_vars)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:97 _disallow_undefs_into_loop
        ' before the loop: {}'.format(tuple(s.symbol_name for s in undefined)))

    ValueError: TensorFlow requires that the following symbols must be defined before the loop: ('x',)

正确版本

@tf.function
def f():
  x = tf.constant(0)
  for i in tf.range(3):
    x = i
  return x

f()
# 输出

保持类型和形状一致

在迭代中要保持参数形状和类型不变

错误示范

@tf.function
def buggy_loop_type_changes():
  x = tf.constant(0, dtype=tf.float32)
  for i in tf.range(3): # Yields tensors of type tf.int32...
    x = i
  return x

with assert_raises(TypeError):
  buggy_loop_type_changes()

# 输出
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 9, in 
    buggy_loop_type_changes()
TypeError: in converted code:

    :4 buggy_loop_type_changes  *
        for i in tf.range(3): # Yields tensors of type tf.int32...
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:478 _tf_range_for_stmt
        opts=opts,
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:794 _tf_while_stmt
        aug_init_vars, **opts)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/control_flow_ops.py:2675 while_loop
        back_prop=back_prop)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:194 while_loop
        add_control_dependencies=add_control_dependencies)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:172 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:784 aug_body
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:195 _verify_tf_loop_vars
        first_iter_var)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 map_structure
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:179 _check_same_type
        first_iter_var.dtype.name,

    TypeError: "x" has dtype float32 before the loop, but dtype int32 after one iteration. TensorFlow control flow requires it stays the same.

改变形状的错误示范

@tf.function
def buggy_concat():
  x = tf.ones([0, 10])
  for i in tf.range(5):
    x = tf.concat([x, tf.ones([1, 10])], axis=0)
  return x

with assert_raises(ValueError):
  buggy_concat()

# 输出
Caught expected exception 
  :

Traceback (most recent call last):
  File "", line 8, in assert_raises
    yield
  File "", line 9, in 
    buggy_concat()
ValueError: in converted code:

    :4 buggy_concat  *
        for i in tf.range(5):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:478 _tf_range_for_stmt
        opts=opts,
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:794 _tf_while_stmt
        aug_init_vars, **opts)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/control_flow_ops.py:2675 while_loop
        back_prop=back_prop)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:194 while_loop
        add_control_dependencies=add_control_dependencies)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:172 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:784 aug_body
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:195 _verify_tf_loop_vars
        first_iter_var)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 map_structure
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:191 _check_same_type
        first_iter_shape))

    ValueError: "x" has shape (0, 10) before the loop, but shape (1, 10) after one iteration. TensorFlow control flow requires it stays the same or be more specific.


@tf.function
def concat_with_padding():
  x = tf.zeros([5, 10])
  for i in tf.range(5):
    x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)
    x.set_shape([5, 10])
  return x

concat_with_padding()

# 输出




参考文献 https://tensorflow.google.cn/tutorials/customization/performance

你可能感兴趣的:(挑选出tensor中等于0的索引_[TensorFlow2.0文档翻译] 使用tf.function提升性能)