tensorflow源码分析 session client

本文主要探索tensorflow python端如何与C++交互,以及创建session,创建图并运行图的整理流程。


先看下面这张图


引用自 https://github.com/tensorflow/ecosystem

其中

  • client 指的是python代码,分布式环境下每个worker进程里都有一个client
  • master 指的是每个worker进程中的master线程
  • worker 指的是每个worker进程中的worker线程
  • ps 指ps进程

也就是说,每个worker进程中都会有1个master线程和1个worker线程,client和master进行通信,通信的方式为grpc。这是tensorflow分布式环境在 between-graph 模式下的整体架构。

本文主要关注client,也就是python部分。

先看这个
https://www.tensorflow.org/api_docs/python/tf/Session
里面提到源码在 文件 tensorflow/python/client/session.py

其中先定义了一个接口 SessionInterface,接口主要的方法如下:

  @property
  def graph(self):
    raise NotImplementedError('graph')

  @property
  def sess_str(self):
    raise NotImplementedError('sess_str')

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    raise NotImplementedError('run')

  def partial_run_setup(self, fetches, feeds=None):
    raise NotImplementedError('partial_run_setup')

  def partial_run(self, handle, fetches, feed_dict=None):
    raise NotImplementedError('partial_run')

然后又定义了一个实现 BaseSession,里面对这个类的注释如下

class BaseSession(SessionInterface):
  """A class for interacting with a TensorFlow computation.

  The BaseSession enables incremental graph building with inline
  execution of Operations and evaluation of Tensors.

就是说这个BaseSession可以增量构建计算图。

最后是SessionInteractiveSession,此次只看 Session

它实现了一个python的上下文管理器,有 __enter____exit__方法,当以with语法运行的时候,上下文管理器会自动管理session对象的创建和销毁,包括关闭session。仔细看它的__init__方法,其实还是对BaseSession的简单继承。所以回到BaseSession

BaseSession__init__方法里可以看到,最终调用了TF_NewSession,而TF_NewSession 又来自 pywrap_tensorflow。源码如下

from tensorflow.python import pywrap_tensorflow as tf_session
...
...
...
    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      if self._created_with_new_api:
        # pylint: disable=protected-access
        self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
        # pylint: enable=protected-access
      else:
        self._session = tf_session.TF_NewDeprecatedSession(opts)
    finally:
      tf_session.TF_DeleteSessionOptions(opts)

这里pywrap_tensorflow就是C API,通过swig实现python调用C函数的功能,swig部分在文件 tensorflow/python/client/tf_session.i中,而C API的代码在如下两个文件。

  • tensorflow/c/c_api.h
  • tensorflow/c/c_api.cc

看到这里,其实说明tensorflow python部分主要还是调用C++库来实现功能的,启动服务器是C++启动,python接口调用;给服务器发送grpc请求也是调用的C API,而不是直接用python来写grpc,这样也比较合理,既能提高性能,又能减少python client端依赖。

图 Graph

graph部分我主要想了解2个问题:

  • 1 图是怎么交给C++引擎执行的?
  • 2 图中节点的device信息如何存储?

在tensorflow里,图和会话是密不可分的。训练模型的时候,先定义计算图,然后创建会话,最后利用会话来运行图。详细参考图和会话。关于图的接口在 tf.Graph 里有详细描述,代码在 tensorflow/python/framework/ops.py。

在tensorflow里,所有程序运行的时候,总是会有一张图,这个图包含了op和tensor,op表示图中的节点,tensor表示图中节点间的数据流,这个图就是数据流图。如果用户什么都不做,程序内部会自动创建一张默认的图,把用户当前创建的所有op都添加到这张图里,同时,用户也可以是自己创建一个图并指定为当前程序默认的数据流图。

代码如下:

  c = tf.constant(4.0)
  assert c.graph is tf.get_default_graph()

  g = tf.Graph()
  with g.as_default():
    # Define operations and tensors in `g`.
    c = tf.constant(30.0)
    assert c.graph is g

为什么一定要表示成数据流图? 最直接的原因就是便于分布式并行执行op。所以这里每个op的位置信息就非常重要了,关于tensorflow如何进行node的placement,后面单独写文章说明。这里假设op的device信息是已知的。

Graph里有2个字典类型的属性,_nodes_by_id和_nodes_by_name,里面看来是保存了op的信息。

    self._nodes_by_id = dict()  # GUARDED_BY(self._lock)
    self._nodes_by_name = dict()  # GUARDED_BY(self._lock)

没有发现其他与device相关的信息,所以device信息应该是包含在op内部的。

Operation的API说明在 tf.Operation,定义仍然是在ops.py里。operation直接有_set_device方法,该方法被_create_op_helpertf.device_apply_device_functions调用。 _set_device 内部则调用了 C API c_api.SetRequestedDevice, 可以看到operation还有一个 device方法,内部调用了 c_api.TF_OperationDevice,这两个函数可以用来分别set和get op的device信息,用到的key是 _c_op,也就是 operation __init__的第一个参数 node_def

继续看 c_api.SetRequestedDevice,发现定义在

  • tensorflow/c/python_api.h
  • tensorflow/c/python_api.cc

代码如下

void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
  mutex_lock l(graph->mu);
  op->node.set_requested_device(device);
  RecordMutation(graph, *op, "setting device");
}

继续跟进,定义在

  • tensorflow/core/graph/graph.h
  • tensorflow/core/graph/graph.cc

代码如下

void Node::set_requested_device(const string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}

set_device 定义在 NodeDef 的pb文件里

  • tensorflow/core/framework/node_def.proto

如下

message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;

  // A (possibly partial) specification for the device on which this
  // node should be placed.
  // The expected syntax for this string is as follows:
  //
  // DEVICE_SPEC ::= PARTIAL_SPEC
  //
  // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
  // CONSTRAINT ::= ("job:" JOB_NAME)
  //              | ("replica:" [1-9][0-9]*)
  //              | ("task:" [1-9][0-9]*)
  //              | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
  //
  // Valid values for this string include:
  // * "/job:worker/replica:0/task:1/device:GPU:3"  (full specification)
  // * "/job:worker/device:GPU:3"                   (partial specification)
  // * ""                                    (no specification)
  //
  // If the constraints do not resolve to a single device (or if this
  // field is empty or not present), the runtime will attempt to
  // choose a device automatically.
  string device = 4;
  map attr = 5;
};

其他几个字段的注释我都去掉了,只看 device 字段。这里面就包含了每个node应该运行在哪个device上的具体信息。从这里可以看出,node的运行位置信息只有通过修改node def才能实现,包含在node定义之内,外部没有单独存储node的device信息。

所以回到一开始的第2个问题,graph中节点的device信息就定义在node内部,随着图一起传给C++执行引擎。

再看第1个问题,先跟踪BaseSession 的执行流程:

run ->_run ->_do_run -> _do_call
_extend_graph -> tf_session.ExtendSession
_call_tf_sessionrun -> tf_session.TF_SessionRun_wrapper

以上,在_do_run里定义了2个函数,然后把函数传给了_do_call,最终在_do_call里执行传入的函数。显然,BaseSession的增量建图是在_extend_graph里实现的。

考察graph在整个流程中的变化,首先BaseSession__init__接收一个graph参数,如果这个graphNone,就会使用ops.get_default_graph(),否则使用用户传入的graph。在调用C++端 tf_session.TF_NewSession 的时候会传入graph的一个属性 self._graph._c_graph,这就是C++端使用的真正的图的定义。

然后,当运行 session.run()的时候,使用的也就是session_graph属性。这里看2个问题,_c_graph定义是什么?_extend_graph内部逻辑?第2个问题比较简单,直接交给了C++端的 tf_session.ExtendSession来实现。

Graph和Session都是class,在ops.py中找到Graph的定义。是一个被装饰器修饰的属性:

  @property
  def _c_graph(self):
    if self._scoped_c_graph:
      return self._scoped_c_graph.graph
    return None

对这个属性进行文件内全局搜索发现,在_create_c_op中有如下代码:

  op_desc = c_api.TF_NewOperation(graph._c_graph,
                                  compat.as_str(node_def.op),
                                  compat.as_str(node_def.name))

_create_c_op 被 类 Operation__init__方法调用,可见 operation在创建的时候就直接被加入了图graph里,并且是通过TF_NewOperation直接加入_c_graph

由此可窥见tensorflow静态图的整个流程。首先有必须有一个(默认或自定义)数据流图,然后每个op都被直接加入到这个图里,建立session的时候,图会被传给C++端的session,而session.run会先 _extend_graph 然后 TF_SessionRun_wrapper,最终得到执行结果。session.run 接收的参数是一个 fetch,最常见的是一个op。

这个流程也和官方文档描述一致,通过分析源码,对整个流程的执行代码更清晰。

(本文完)

你可能感兴趣的:(tensorflow源码分析 session client)