本文主要探索tensorflow python端如何与C++交互,以及创建session,创建图并运行图的整理流程。
先看下面这张图
其中
- client 指的是python代码,分布式环境下每个worker进程里都有一个client
- master 指的是每个worker进程中的master线程
- worker 指的是每个worker进程中的worker线程
- ps 指ps进程
也就是说,每个worker进程中都会有1个master线程和1个worker线程,client和master进行通信,通信的方式为grpc。这是tensorflow分布式环境在 between-graph
模式下的整体架构。
本文主要关注client,也就是python部分。
先看这个
https://www.tensorflow.org/api_docs/python/tf/Session
里面提到源码在 文件 tensorflow/python/client/session.py
其中先定义了一个接口 SessionInterface
,接口主要的方法如下:
@property
def graph(self):
raise NotImplementedError('graph')
@property
def sess_str(self):
raise NotImplementedError('sess_str')
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
raise NotImplementedError('run')
def partial_run_setup(self, fetches, feeds=None):
raise NotImplementedError('partial_run_setup')
def partial_run(self, handle, fetches, feed_dict=None):
raise NotImplementedError('partial_run')
然后又定义了一个实现 BaseSession
,里面对这个类的注释如下
class BaseSession(SessionInterface):
"""A class for interacting with a TensorFlow computation.
The BaseSession enables incremental graph building with inline
execution of Operations and evaluation of Tensors.
就是说这个BaseSession可以增量构建计算图。
最后是Session
和 InteractiveSession
,此次只看 Session
。
它实现了一个python的上下文管理器,有 __enter__
和 __exit__
方法,当以with
语法运行的时候,上下文管理器会自动管理session对象的创建和销毁,包括关闭session。仔细看它的__init__
方法,其实还是对BaseSession
的简单继承。所以回到BaseSession
。
在BaseSession
的__init__
方法里可以看到,最终调用了TF_NewSession
,而TF_NewSession
又来自 pywrap_tensorflow
。源码如下
from tensorflow.python import pywrap_tensorflow as tf_session
...
...
...
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
if self._created_with_new_api:
# pylint: disable=protected-access
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
# pylint: enable=protected-access
else:
self._session = tf_session.TF_NewDeprecatedSession(opts)
finally:
tf_session.TF_DeleteSessionOptions(opts)
这里pywrap_tensorflow就是C API,通过swig实现python调用C函数的功能,swig部分在文件 tensorflow/python/client/tf_session.i
中,而C API的代码在如下两个文件。
- tensorflow/c/c_api.h
- tensorflow/c/c_api.cc
看到这里,其实说明tensorflow python部分主要还是调用C++库来实现功能的,启动服务器是C++启动,python接口调用;给服务器发送grpc请求也是调用的C API,而不是直接用python来写grpc,这样也比较合理,既能提高性能,又能减少python client端依赖。
图 Graph
graph部分我主要想了解2个问题:
- 1 图是怎么交给C++引擎执行的?
- 2 图中节点的device信息如何存储?
在tensorflow里,图和会话是密不可分的。训练模型的时候,先定义计算图,然后创建会话,最后利用会话来运行图。详细参考图和会话。关于图的接口在 tf.Graph 里有详细描述,代码在 tensorflow/python/framework/ops.py。
在tensorflow里,所有程序运行的时候,总是会有一张图,这个图包含了op和tensor,op表示图中的节点,tensor表示图中节点间的数据流,这个图就是数据流图。如果用户什么都不做,程序内部会自动创建一张默认的图,把用户当前创建的所有op都添加到这张图里,同时,用户也可以是自己创建一个图并指定为当前程序默认的数据流图。
代码如下:
c = tf.constant(4.0)
assert c.graph is tf.get_default_graph()
或
g = tf.Graph()
with g.as_default():
# Define operations and tensors in `g`.
c = tf.constant(30.0)
assert c.graph is g
为什么一定要表示成数据流图? 最直接的原因就是便于分布式并行执行op。所以这里每个op的位置信息就非常重要了,关于tensorflow如何进行node的placement,后面单独写文章说明。这里假设op的device信息是已知的。
Graph里有2个字典类型的属性,_nodes_by_id和_nodes_by_name,里面看来是保存了op的信息。
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
没有发现其他与device相关的信息,所以device信息应该是包含在op内部的。
Operation的API说明在 tf.Operation,定义仍然是在ops.py
里。operation直接有_set_device
方法,该方法被_create_op_helper
和 tf.device
的 _apply_device_functions
调用。 _set_device
内部则调用了 C API c_api.SetRequestedDevice
, 可以看到operation还有一个 device
方法,内部调用了 c_api.TF_OperationDevice
,这两个函数可以用来分别set和get op的device信息,用到的key是 _c_op
,也就是 operation __init__
的第一个参数 node_def
。
继续看 c_api.SetRequestedDevice
,发现定义在
- tensorflow/c/python_api.h
- tensorflow/c/python_api.cc
代码如下
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
mutex_lock l(graph->mu);
op->node.set_requested_device(device);
RecordMutation(graph, *op, "setting device");
}
继续跟进,定义在
- tensorflow/core/graph/graph.h
- tensorflow/core/graph/graph.cc
代码如下
void Node::set_requested_device(const string& device) {
MaybeCopyOnWrite();
props_->node_def.set_device(device);
}
而set_device
定义在 NodeDef
的pb文件里
- tensorflow/core/framework/node_def.proto
如下
message NodeDef {
string name = 1;
string op = 2;
repeated string input = 3;
// A (possibly partial) specification for the device on which this
// node should be placed.
// The expected syntax for this string is as follows:
//
// DEVICE_SPEC ::= PARTIAL_SPEC
//
// PARTIAL_SPEC ::= ("/" CONSTRAINT) *
// CONSTRAINT ::= ("job:" JOB_NAME)
// | ("replica:" [1-9][0-9]*)
// | ("task:" [1-9][0-9]*)
// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
//
// Valid values for this string include:
// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification)
// * "/job:worker/device:GPU:3" (partial specification)
// * "" (no specification)
//
// If the constraints do not resolve to a single device (or if this
// field is empty or not present), the runtime will attempt to
// choose a device automatically.
string device = 4;
map attr = 5;
};
其他几个字段的注释我都去掉了,只看 device 字段。这里面就包含了每个node应该运行在哪个device上的具体信息。从这里可以看出,node的运行位置信息只有通过修改node def才能实现,包含在node定义之内,外部没有单独存储node的device信息。
所以回到一开始的第2个问题,graph中节点的device信息就定义在node内部,随着图一起传给C++执行引擎。
再看第1个问题,先跟踪BaseSession
的执行流程:
run ->_run ->_do_run -> _do_call
_extend_graph -> tf_session.ExtendSession
_call_tf_sessionrun -> tf_session.TF_SessionRun_wrapper
以上,在_do_run
里定义了2个函数,然后把函数传给了_do_call
,最终在_do_call
里执行传入的函数。显然,BaseSession
的增量建图是在_extend_graph
里实现的。
考察graph在整个流程中的变化,首先BaseSession
的__init__
接收一个graph参数,如果这个graph
为None
,就会使用ops.get_default_graph()
,否则使用用户传入的graph。在调用C++端 tf_session.TF_NewSession
的时候会传入graph的一个属性 self._graph._c_graph
,这就是C++端使用的真正的图的定义。
然后,当运行 session.run()
的时候,使用的也就是session
的_graph
属性。这里看2个问题,_c_graph
定义是什么?_extend_graph
内部逻辑?第2个问题比较简单,直接交给了C++端的 tf_session.ExtendSession
来实现。
Graph和Session都是class,在ops.py中找到Graph的定义。是一个被装饰器修饰的属性:
@property
def _c_graph(self):
if self._scoped_c_graph:
return self._scoped_c_graph.graph
return None
对这个属性进行文件内全局搜索发现,在_create_c_op
中有如下代码:
op_desc = c_api.TF_NewOperation(graph._c_graph,
compat.as_str(node_def.op),
compat.as_str(node_def.name))
而_create_c_op
被 类 Operation
的 __init__
方法调用,可见 operation在创建的时候就直接被加入了图graph里,并且是通过TF_NewOperation
直接加入_c_graph
。
由此可窥见tensorflow静态图的整个流程。首先有必须有一个(默认或自定义)数据流图,然后每个op都被直接加入到这个图里,建立session的时候,图会被传给C++端的session,而session.run
会先 _extend_graph
然后 TF_SessionRun_wrapper
,最终得到执行结果。session.run
接收的参数是一个 fetch
,最常见的是一个op。
这个流程也和官方文档描述一致,通过分析源码,对整个流程的执行代码更清晰。
(本文完)