This article is adapted from a slide deck I made; it presents my understanding of TensorFlow's dataflow machine after reading the TensorFlow source code. Since much of the content comes from the papers, I did not bother translating it and simply made the slides in English. Some images in this article are a bit blurry; if you cannot make them out, please refer to my original slides. One highlight of this article is that, taking a code snippet as an example, I drew the execution flow of a TensorFlow computation graph with ProcessOn.
This post discusses TensorFlow control flow and how TensorFlow executes a TF graph.
This material is distilled from:
1、TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems
2、TensorFlow: A System for Large-Scale Machine Learning
3、Implementation of Control Flow in TensorFlow
4、Dynamic Control Flow in Large-Scale Machine Learning (https://arxiv.org/pdf/1805.01772.pdf)
5、https://github.com/tensorflow/tensorflow
When you run TensorFlow using the Python front end, the work can be divided into four steps (a minimal example follows the list):
•Define a model using TF code, typically in Python
•The TensorFlow front end implicitly defines a computation graph
•Launch the graph through a session
•The TensorFlow runtime executes the graph
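Here is a minimal sketch of these four steps using the TF 1.x API (the constants and the multiply op are illustrative only):

import tensorflow as tf

# Step 1: define a model in TF code, in Python
a = tf.constant(2.0)
b = tf.constant(3.0)
c = a * b  # Step 2: the front end implicitly builds the computation graph

# Step 3: launch the graph through a session
with tf.Session() as sess:
    # Step 4: the TensorFlow runtime executes the graph
    print(sess.run(c))  # 6.0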
In a way, TensorFlow can be regarded as a programming language, a virtual machine, or both: the TensorFlow front end is the programming language, the TensorFlow graph is a kind of IR (intermediate representation), and the TensorFlow back-end runtime is the virtual machine.
As is the case with any serious programming language, control flow is an integral part of TensorFlow's programming interfaces.
Indeed, control flow is exactly what makes a TF graph more expressive than a plain DAG, by making conditional and iterative computation possible. The relevant APIs are:
•tf.cond
•tf.while_loop
•tf.group
•tf.tuple
•tf.case
The definitions of these APIs can be found in tensorflow/python/ops/control_flow_ops.py.
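As a minimal illustration of one of these APIs (TF 1.x API, with made-up operands), tf.cond builds both branches into the graph and picks one at run time:

import tensorflow as tf

x = tf.constant(2)
y = tf.constant(5)

# Both branch functions are traced into subgraphs at graph-construction
# time; at run time only the ops of the taken branch execute.
z = tf.cond(x < y, lambda: x + y, lambda: tf.square(y))

with tf.Session() as sess:
    print(sess.run(z))  # prints 7, since 2 < 5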
Control flow is supported through five primitive operations. Control-flow constructs such as tf.cond and tf.while_loop are compiled into subgraphs built from these primitive operations.
These five primitive operations are treated specially by graph executors during the execution of a TF graph. As we will see later, they play a very important role: the execution of a TF graph is orchestrated by these five primitives.
In TensorFlow, every op is executed in an execution frame, and the control-flow primitives are responsible for creating and managing these execution frames.
Intuitively, for each while loop, the TensorFlow runtime sets up an execution frame and runs all the ops belonging to the while loop inside that frame. While loops can be nested, and ops in nested while loops run in nested execution frames. A nested while loop may have multiple execution frame instances, one for each iteration of the parent execution frame. Ops from different execution frames can run in parallel as long as there is no dependency between them.
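For example, in the following sketch (TF 1.x API, illustrative numbers), the inner loop's ops live in a child execution frame, and the runtime instantiates a fresh instance of that frame for every iteration of the outer loop:

import tensorflow as tf

def outer_body(i):
    # The inner while loop runs in a child execution frame,
    # re-instantiated on every outer iteration.
    inner = tf.while_loop(lambda j: j < 3,
                          lambda j: j + 1,
                          [tf.constant(0)])
    return i + inner  # adds 3 on each outer iteration

result = tf.while_loop(lambda i: i < 6, outer_body, [tf.constant(0)])

with tf.Session() as sess:
    print(sess.run(result))  # 6, after two outer iterations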
•Switch: A Switch operator forwards the input tensor d to one of its outputs depending on the boolean tensor of the control input p. A Switch is enabled for execution when both its inputs are available.
•Merge: A Merge operator forwards one of its available inputs to its output. A Merge is enabled for execution when any of its inputs is available. It is unspecified which available input it forwards if multiple inputs are available (see the Switch/Merge sketch after this list).
•Enter(name): An Enter operator forwards its input to the execution frame that is uniquely identified by the given name. This Enter op is used to pass a tensor in one execution frame to a child execution frame. There can be multiple Enter ops to the same child execution frame, each making a tensor available (asynchronously) in that child execution frame. An Enter is enabled for execution when its input is available. A new execution frame is instantiated in the TensorFlow runtime when the first Enter op to that frame is executed.
•Exit: An Exit operator forwards a value from an execution frame to its parent execution frame. This Exit op is used to return a tensor computed in a child execution frame back to its parent frame. There can be multiple Exit ops to the parent frame, each asynchronously passing a tensor back to the parent frame. An Exit is enabled when its input is available.
•NextIteration: A NextIteration operator forwards its input to the next iteration in the current execution frame. The TensorFlow runtime keeps track of iterations in an execution frame. Any op executed in an execution frame has a unique iteration id, which allows us to uniquely identify different invocations of the same op in an iterative computation. Note that there can be multiple NextIteration ops in an execution frame. The TensorFlow runtime starts iteration N+1 (creating an IterationState object) when the first NextIteration op is executed at iteration N. As more tensors enter an iteration by executing NextIteration ops, more ops in that iteration become ready for execution. A NextIteration is enabled when its input is available.
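To make the primitives concrete, here is a hedged sketch that wires a conditional by hand out of Switch and Merge, using TF 1.x internal helpers (control_flow_ops.switch and control_flow_ops.merge); this is essentially the subgraph that tf.cond compiles to:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

x = tf.constant(10.0)
pred = tf.constant(True)

# switch() returns (output_false, output_true); the branch not taken
# receives a dead tensor.
x_false, x_true = control_flow_ops.switch(x, pred)
out_true = x_true * 2.0    # computed only when pred is True
out_false = x_false + 1.0  # dead when pred is True

# merge() forwards whichever input is alive.
result, _ = control_flow_ops.merge([out_false, out_true])

with tf.Session() as sess:
    print(sess.run(result))  # 20.0, since pred is True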
The executor follows these rules of execution:
Switch(p, d) = (r1, r2):
    r1 = (value(d), p || is_dead(d), tag(d))
    r2 = (value(d), !p || is_dead(d), tag(d))
Merge(d1, d2) = r:
    r = if is_dead(d1) then d2 else d1
Enter(d, frame_name) = r:
    value(r) = value(d)
    is_dead(r) = is_dead(d)
    tag(r) = tag(d)/frame_name/0
Exit(d) = r:
    value(r) = value(d)
    is_dead(r) = is_dead(d)
    tag(r) = tag1 where tag(d) = tag1/frame_name/n
NextIteration(d) = d1:
    value(d1) = value(d)
    is_dead(d1) = is_dead(d)
    tag(d1) = tag1/frame_name/(n+1) where tag(d) = tag1/frame_name/n
Op(d1, …, dm) = (r1, …, rn):  // any other op
    value(ri) = Op.Compute(value(d1), …, value(dm)) if !is_dead(ri)
    is_dead(ri) = any(is_dead(d1), …, is_dead(dm)), for all i
    tag(ri) = tag(d1), for all i
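To illustrate the bookkeeping, here is a toy Python model of these rules; it is a sketch for exposition only, not TensorFlow's actual C++ implementation (the Tagged type and the string tag encoding are invented here):

from dataclasses import dataclass

@dataclass
class Tagged:
    value: object
    is_dead: bool
    tag: str  # e.g. "root/whileloop/0"

def switch(p, d):
    r1 = Tagged(d.value, p or d.is_dead, d.tag)        # false output
    r2 = Tagged(d.value, (not p) or d.is_dead, d.tag)  # true output
    return r1, r2

def merge(d1, d2):
    return d2 if d1.is_dead else d1

def enter(d, frame_name):
    return Tagged(d.value, d.is_dead, f"{d.tag}/{frame_name}/0")

def exit_(d):
    base, _frame, _n = d.tag.rsplit("/", 2)  # strip frame_name/n
    return Tagged(d.value, d.is_dead, base)

def next_iteration(d):
    base, frame, n = d.tag.rsplit("/", 2)
    return Tagged(d.value, d.is_dead, f"{base}/{frame}/{int(n) + 1}")

d = Tagged(42, False, "root")
e = enter(d, "whileloop")
print(next_iteration(e).tag)  # root/whileloop/1
print(exit_(e).tag)           # root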
Let's look at a TF code snippet that computes the second Fibonacci number:
import tensorflow as tf

def cond(a, b, i):
    print("from cond")
    return i < 2

def body(a, b, i):
    print("from body")
    return (b, a + b, i + 1)

r1, r2, r3 = tf.while_loop(cond, body, [1, 1, 1])

with tf.Session() as sess:
    writer = tf.summary.FileWriter("logs/", sess.graph)
    tf.global_variables_initializer().run()
    print(r1.eval())
    g = tf.get_default_graph()._as_graph_def()
    print(g)
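If you print the graph def, you can see that tf.while_loop has been compiled into the five primitives (plus a LoopCond op). For instance, this small check, run after building the graph above, lists the control-flow ops present (illustrative only):

import tensorflow as tf

ops = {n.op for n in tf.get_default_graph().as_graph_def().node}
print(ops & {"Enter", "Merge", "Switch", "Exit", "NextIteration", "LoopCond"})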
•Static information is collected for the frames and nodes.
•A root frame is initialized, and the iteration state for the first iteration of that frame is created.
•Nodes with zero input edges are put into a ready queue.
•A tagged node is taken from the ready queue.
•Since the tag of the tagged node (const, 0, false) is not dead, a kernel is initialized for the node.
•The output of the node is propagated to Enter and put in the input slot for Enter in iteration state 0.
•The pending count of Enter is decremented by 1. The pending count now becomes zero while the dead count stays zero, hence Enter is eligible to execute and is put into the ready queue, tagged with iteration zero and not dead (the sketch after this list models this ready-queue mechanism).
•The same goes for const1 and const2, which is omitted here.
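The ready-queue/pending-count scheduling described above can be sketched in a few lines of toy Python; this is an illustration only, with invented names (the real logic lives in tensorflow/core/common_runtime/executor.cc):

from collections import deque
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    compute: object                   # callable taking the input dict
    pending: int = 0                  # inputs not yet delivered
    inputs: dict = field(default_factory=dict)
    out_edges: list = field(default_factory=list)  # (dst_node, dst_slot)

def run(nodes):
    ready = deque(n for n in nodes if n.pending == 0)
    while ready:
        node = ready.popleft()
        result = node.compute(node.inputs)  # run the kernel
        print(f"executed {node.name} -> {result}")
        for dst, slot in node.out_edges:
            dst.inputs[slot] = result  # put the output in the input slot
            dst.pending -= 1           # one fewer pending input
            if dst.pending == 0:
                ready.append(dst)      # now eligible to execute

# A trivial graph: add = a + b
a = Node("a", lambda _: 1.0)
b = Node("b", lambda _: 2.0)
add = Node("add", lambda ins: ins[0] + ins[1], pending=2)
a.out_edges.append((add, 0))
b.out_edges.append((add, 1))
run([a, b, add])  # executes a, b, then add -> 3.0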
•Enter is taken from the ready queue.
•Since it is not dead, a kernel is created for it.
•The tensor in the input slot for Enter in iteration state 0 is copied into an OpKernelContext object.
•The kernel computes with that OpKernelContext object.
•As this is the first time the "whileloop" frame is entered, a new FrameState is spawned, and the first iteration state for "whileloop" is created.
•The output is propagated to Merge in the child frame "whileloop".
•The lowest bit of the pending count for Merge indicates whether any data has arrived yet. If it is 1, the input is put in the slot for Merge and the bit is cleared; otherwise the input is dropped.
•When a NextIteration op is first met in iteration 0, an iteration state for iteration 1 is created.
•The iteration count is incremented by 1.
•Saved dead exits are deleted.
•The last op for iteration 0 of frame "whileloop" is executed.
•outstanding_ops and outstanding_frame_count are both zero, hence iteration 0 is done.
•The iteration state for iteration 0 is destroyed.
•num_outstanding_iterations is decreased by 1.
The last op is executed.
As the op is a dead NextIteration, it does not propagate to the next iteration, so the iterative computation stops here.
This iteration is done, and its iteration state is destroyed.
This frame is done, hence destroyed.
The parent iteration is done and destroyed too.
The parent frame is done and destroyed.