In TensorFlow 2.0, eager execution is turned on by default. This gives you a very intuitive and flexible interface; running one-off operations is much easier and faster, but it can come at the expense of performance and deployability.

To get peak performance, and to make your model deployable anywhere, we provide tf.function as the tool you can use to generate graphs from your programs.
from __future__ import absolute_import, division, print_function, unicode_literals
!pip install -q tensorflow==2.0.0-alpha0
import tensorflow as tf
# A function behaves like an op

@tf.function
def add(a, b):
  return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
A tf.function you define is just like a core TensorFlow operation: you can execute it eagerly, you can use it in graphs, it has gradients, and so on.

# Functions have gradients
@tf.function
def add(a, b):
  return a + b

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)  # 1.0, since d(v + 1.0)/dv == 1
# You can use functions inside functions

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))  # [[3., 3.], [3., 3.], [3., 3.]]
Polymorphism

tf.function tries to be as generic as a Python function. You can call Python functions with all sorts of signatures, and Python will usually do something reasonable. tf.function handles this kind of polymorphism for you, even though the TensorFlow graphs it generates under the hood are specific to the particular types in their signature.

You can call the function with arguments of different types to see what happens:

# Functions are polymorphic

@tf.function
def add(a):
  return a + a

print("add 1", add(1))                              # tf.Tensor(2, shape=(), dtype=int32)
print("add 1.1", add(1.1))                          # tf.Tensor(2.2, shape=(), dtype=float32)
print("add string tensor", add(tf.constant("a")))   # tf.Tensor(b'aa', shape=(), dtype=string)
# You can also obtain the graph for a particular signature explicitly:
c = add.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
c(a=tf.constant("a"))  # aa
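Each distinct argument signature triggers a fresh trace and its own concrete function, so a highly polymorphic function can accumulate many graphs. If you want to pin a function to a single signature up front, you can pass tf.function's input_signature argument. A minimal sketch; the shape and dtype here are illustrative assumptions, not requirements:

# A sketch: restrict the function to one signature so it is traced exactly once.
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def double(a):
  return a + a

print(double(tf.constant(1.5)))   # tf.Tensor(3.0, shape=(), dtype=float32)
# double(tf.constant("a"))        # would raise: does not match the declared signature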
# For graphs with many small ops, functions can run faster than eager code
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
  return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))    # ~0.2097
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))    # ~0.2106
print("Note how there's not much difference in performance for convolutions")
lstm_cell = tf.keras.layers.LSTMCell(10)
@tf.function
def lstm_fn(input, state):
  return lstm_cell(input, state)
input = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# Warm up
lstm_cell(input, state); lstm_fn(input, state)
print("eager lstm:", timeit.timeit(lambda: lstm_cell(input, state), number=10))    # ~0.0339
print("function lstm:", timeit.timeit(lambda: lstm_fn(input, state), number=10))   # ~0.0053
a = tf.Variable(1.0)
b = tf.Variable(2.0)
@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0
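For contrast, here is a sketch of the manual ordering you would otherwise have to spell out yourself with tf.control_dependencies. Inside tf.function it is redundant, since the same ordering is inferred automatically from the Python program order; the a2/b2/f_manual names are just for this illustration:

a2 = tf.Variable(1.0)
b2 = tf.Variable(2.0)

@tf.function
def f_manual(x, y):
  op1 = a2.assign(y * b2)
  with tf.control_dependencies([op1]):   # redundant inside tf.function;
    op2 = b2.assign_add(x * a2)          # shown only to illustrate what the
  with tf.control_dependencies([op2]):   # automatic dependencies save you
    return a2 + b2

f_manual(1.0, 2.0)  # 10.0, same as above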
Variables

We can use the same idea, exploiting the intended execution order of the code, to make variable creation and use inside tf.function much easier. There is one important caveat, though: with variables it is possible to write code that behaves differently when called eagerly multiple times than when its output tensor is evaluated multiple times.

A simple example:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

f(1.)  # Note: BROKEN, this will raise an exception
If you ran this code eagerly, you would always get the answer "2"; but if you repeatedly evaluated the Tensor obtained from f(1.) in a graph context, you would get ever-increasing numbers. So tf.function does not let you write code like this.

# Unambiguous code is fine, though
v = tf.Variable(1.0)
@tf.function
def f(x):
  return v.assign_add(x)

f(1.0)  # 2.0
f(2.0)  # 4.0
# You can also create variables inside a tf.function, as long as we can verify
# that they are created only the first time the function executes.
class C: pass
obj = C(); obj.v = None
@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

g(1.0)  # 2.0
g(2.0)  # 4.0
# Variable initializers can depend on function arguments and on the values of
# other variables. We can work out the right initialization order using the same
# method we use to generate control dependencies.
state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

fn(tf.constant(1.0))  # variables are created on this first trace; returns 2.0 * 1.0 * 6.0 == 12.0
fn(tf.constant(3.0))  # returns 2.0 * 3.0 * 6.0 == 36.0
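In practice, the most common way to get this behavior is to let an object own its variables. A Keras layer, for example, creates its variables lazily on the first call, which composes cleanly with tf.function. A minimal sketch; the Dense layer and shapes here are illustrative choices, not part of the original example:

layer = tf.keras.layers.Dense(2)

@tf.function
def apply_layer(x):
  return layer(x)  # the layer's kernel and bias are created on the first trace

apply_layer(tf.ones([2, 3]))
print(len(layer.weights))  # 2: kernel and bias, created exactly once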
Control flow and AutoGraph

While tf.cond and tf.while_loop continue to work with tf.function, we provide a better alternative based on a lightweight compilation of your Python code.

The AutoGraph library is fully integrated with tf.function, and it will rewrite conditionals and loops that depend on Tensors so that they run dynamically in the graph.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([10]))  # prints x on every iteration until tf.reduce_sum(x) <= 1
# If you're curious, you can inspect the code AutoGraph generates.
# (Though it feels a lot like reading assembly language.)

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))
To control AutoGraph, remember that it only affects the basic control-flow constructs in Python (if, for, while, break, and so on), and that it only changes them when the predicate is a Tensor.

So in the following example the first loop is statically unrolled, while the second one is dynamically converted:

@tf.function
def f(x):
  for i in range(10):  # Static Python loop; we won't convert it
    do_stuff()
  for i in tf.range(10):  # depends on a Tensor; we'll convert it
    do_stuff()
Similarly, to guarantee that prints and assertions happen dynamically, use tf.print and tf.Assert:

@tf.function
def f(x):
  for i in tf.range(10):
    tf.print(i)
    tf.Assert(i < 10, ["a"])
    x += x
  return x

f(10)  # prints 0 1 2 3 4 5 6 7 8 9
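The distinction matters because plain Python side effects only run while the function is being traced, not on every call. A small sketch (the function name is just for illustration):

@tf.function
def noisy(x):
  print("Python print: runs at trace time only")
  tf.print("tf.print: runs on every call, x =", x)
  return x

noisy(tf.constant(1))  # both lines print on the first call, since tracing happens here
noisy(tf.constant(2))  # only the tf.print line prints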
Finally, AutoGraph cannot compile arbitrary Python code into a TensorFlow graph. In particular, the data structures you use dynamically still need to be TensorFlow data structures.

So, for example, the best way to accumulate data in a loop is still to use tf.TensorArray:

@tf.function
def f(x):
  ta = tf.TensorArray(tf.float32, size=10)
  for i in tf.range(10):
    x += x
    ta = ta.write(i, x)
  return ta.stack()

f(10.0)  # [20., 40., 80., 160., 320., 640., 1280., 2560., 5120., 10240.]
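If the number of iterations is not known ahead of time, a TensorArray can also grow as you write to it, via its standard dynamic_size argument. A sketch; the loop condition and threshold here are arbitrary assumptions for illustration:

@tf.function
def accumulate_until(x):
  # dynamic_size=True lets the TensorArray grow with each write
  ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  i = tf.constant(0)
  while x < 100.0:  # Tensor-dependent loop, converted by AutoGraph
    x *= 2.0
    ta = ta.write(i, x)
    i += 1
  return ta.stack()

accumulate_until(tf.constant(10.0))  # [20., 40., 80., 160.]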