TensorFlow技术内幕(四):TF中的混合编程

本章的主题是TF中的混合编程,以Python与C/C++混合编程为例.

按进度来说,现在应该写点TF使用教程,让大家熟悉一下tensorflow的使用,但是我发现现在这方面的资料和书籍已经很多了,这里就不再赘述了,毕竟时间有限,留给更有意义的事情。

做到熟悉TF使用的最好的方式就是动手实践具体的例子,官网提供的教程就不错。我建议继续阅读本章之前,读者通过这些实际操作的例子,熟悉一下TensorFlow的使用流程。

下面进入本章的主题。

连接两个世界的传送门

回忆上一章中,我们编译、安装tensorflow的方式如下:

首先,bazel build目标

$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package

然后,启动目标程序,生成wheel格式安装包到tmp/tensorflow_pkg目录:

$ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg

最后,通过tensorflow的pip包管理器安装tensorflow的安装包:

$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.6.0-py2-none-any.whl

这个流程我相信大家都已经很熟悉了,我们就从这个顶级的构建目标入手,自顶向下分析TF中Python的调用是如何进入C/C++世界的。

第二章中我们提到,上面命令中的构建目标 //tensorflow/tools/pip_package:build_pip_package 其实是一个可执行的 shell 脚本,通过 bazel 的 sh_binary 规则生成,sh_binary 的规则会将脚本依赖的文件拷贝或则生成(如果被依赖项也是bazel的规则定义的目标的话)到runfiles目录下,脚本build_pip_package.sh的工作就是将这些文件打包成一个wheel格式的安装包。

第一章中,我们学习到tf的内核是由C++写成的,支持的前端API有Python,Go,Java,这些前端语言接口基本都是对C API的封装;那么我们把TF的整个工程分为两个部门来分别学习:接口部分和内核部分;我们还知道,接口部分通往内核部分的最终都会通过//tensorflow/c:c_api。

因此,我们来看下Python的顶级目标与c_api的关系。这里需要用到了 bazel 的query命令。

我们在tf工程根目录执行下面的命令:

$ bazel query  'allpaths(//tensorflow/tools/pip_package:build_pip_package, //tensorflow/c:c_api)' --output graph | dot -v -Tpng -o dep_paths.in

命令的作用是找到bazel目标 //tensorflow/tools/pip_package:build_pip_package 到目标 //tensorflow/c:c_api 的所有依赖路径,并将结果输出为图片:

这里写图片描述

图3:build_pip_package到c_api的所以依赖路径

这用到了graphviz的dot命令,graphviz的安装也很简单:

$ sudo apt-get install graphviz

图3是一个有向图,每一条有向边代表源节点对目的节点的依赖关系,可以看到涉及的工程目标非常多,依赖也繁杂,但是还是可以分析的;首先,我们注意这两个节点:其中一个节点只有出度没有入度,这就是我们都顶级构建目标 //tensorflow/tools/pip_package:build_pip_package

TensorFlow技术内幕(四):TF中的混合编程_第1张图片

图4:build_pip_package 节点

另一个节点只有入度,没有出度,就是我们的 //tensorflow/c:c_api 节点:

TensorFlow技术内幕(四):TF中的混合编程_第2张图片

图5:c_api节点

另外,还有一个节点比较引人注意,那就是 //tensorflow/python:pywrap_tensorflow_internal,我们注意到这样一个特征,整张图在这个节点上分成了上下两个”团体”,每个团体内部的依赖比较复杂,暂时先不用管,但是上层”团体”对下层”团体”的依赖都经过//tensorflow/python:pywrap_tensorflow_internal节点:

TensorFlow技术内幕(四):TF中的混合编程_第3张图片

图6:pywrap_tensorflow_internal

直观感觉,这个节点至关重要,是连接了两个世界的”传送门”。

Python扩展pywrap_tensorflow_internal

找到定义pywrap_tensorflow_internal的BUILD文件//tensorflow/python/BUILD:

tf_py_wrap_cc(
    name = "pywrap_tensorflow_internal",
    srcs = ["tensorflow.i"],
    swig_includes = [
        "client/device_lib.i",
        "client/events_writer.i",
        "client/tf_session.i",
        "client/tf_sessionrun_wrapper.i",
        "framework/cpp_shape_inference.i",
        "framework/python_op_gen.i",
        "grappler/cost_analyzer.i",
        "grappler/tf_optimizer.i",
        "lib/core/py_func.i",
        "lib/core/strings.i",
        "lib/io/file_io.i",
        "lib/io/py_record_reader.i",
        "lib/io/py_record_writer.i",
        "platform/base.i",
        "training/quantize_training.i",
        "training/server_lib.i",
        "util/kernel_registry.i",
        "util/port.i",
        "util/py_checkpoint_reader.i",
        "util/stat_summarizer.i",
        "util/transform_graph.i",
    ],
    deps = [
        ":cost_analyzer_lib",
        ":cpp_shape_inference",
        ":kernel_registry",
        ":numpy_lib",
        ":py_func_lib",
        ":py_record_reader_lib",
        ":py_record_writer_lib",
        ":python_op_gen",
        ":tf_session_helper",
        "//tensorflow/c:c_api",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:grappler_item_builder",
        "//tensorflow/core/grappler/clusters:single_machine",
        "//tensorflow/core/grappler/optimizers:meta_optimizer",
        "//tensorflow/core:lib",
        "//tensorflow/core:reader_base",
        "//tensorflow/core/debug",
        "//tensorflow/core/distributed_runtime:server_lib",
        "//tensorflow/tools/graph_transforms:transform_graph_lib",
        "//tensorflow/tools/tfprof/internal:print_model_analysis",
        "//util/python:python_headers",
    ] + (tf_additional_lib_deps() +
         tf_additional_plugin_deps() +
         tf_additional_verbs_deps() +
         tf_additional_mpi_deps()),
)

这里的tf_py_wrap_cc是bazel的宏函数。那么宏函数tf_py_wrap_cc的作用是什么?这里srcs和swig_includes属性里的扩展名为 .i 文件又是什么呢?

找到宏函数tf_py_wrap_cc定义的文件 /tensorflow/tensorflow.bzl :

def tf_py_wrap_cc(name,
                             srcs,
                             swig_includes=[],
                             deps=[],
                             copts=[],
                             **kwargs):
  module_name = name.split("/")[-1]
  # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
  # and use that as the name for the rule producing the .so file.
  cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"])
  cc_library_pyd_name = "/".join(
      name.split("/")[:-1] + ["_" + module_name + ".pyd"])
  extra_deps = []
  _py_wrap_cc(
      name=name + "_py_wrap",
      srcs=srcs,
      swig_includes=swig_includes,
      deps=deps + extra_deps,
      toolchain_deps=["//tools/defaults:crosstool"],
      module_name=module_name,
      py_module_name=name)
   ... 
   ...

函数的前半部分,主要就是调用一个自定义规则 _py_wrap_cc,此规则声明同在tensorflow/tensorflow.bzl文件中,内容如下:

_py_wrap_cc = rule(
    #
    # 定义规则的输入属性名称,数据类型,是否必须以及默认值等
    #
    attrs={
        "srcs":
            attr.label_list(
                mandatory=True,
                allow_files=True,),
        "swig_includes":
            attr.label_list(
                cfg="data",
                allow_files=True,),
        "deps":
            attr.label_list(
                allow_files=True,
                providers=["cc"],),
        "toolchain_deps":
            attr.label_list(
                allow_files=True,),
        "module_name":
            attr.string(mandatory=True),
        "py_module_name":
            attr.string(mandatory=True),
        "_swig":
            attr.label(
                default=Label("@swig//:swig"),
                executable=True,
                cfg="host",),
        "_swiglib":
            attr.label(
                default=Label("@swig//:templates"),
                allow_files=True,),
    },

    #
    # 定义规则的输出
    #
    outputs={
        "cc_out": "%{module_name}.cc",
        "py_out": "%{py_module_name}.py",
    },

    #
    # 定义规则的实现函数
    #
    implementation=_py_wrap_cc_impl,)

规则声明中,定义了规则的属性、属性的变量类型以及默认取值、规则的输出以及实现函数;_py_wrap_cc规则的实现在函数_py_wrap_cc_impl中,我们将仔细分析一下这个函数。

注:宏函数和自定义规则是bazel支持的两个扩展机制,两则是有差别的,
限于篇幅这里就不仔细介绍了。可以参考
[官网](https://docs.bazel.build/versions/master/skylark/concepts.html)。
暂时读者只要能更随本人思路就可以,细节可以之后再去学习,本人尽力保证在读者不熟悉bazel
的情况下也能理解本文内容。

在详细分析_py_wrap_cc_impl函数之前,我们需要补充一点关于Python和C/C++混合编程的知识。

SWIG

TensorFlow技术内幕(四):TF中的混合编程_第4张图片

图7:Python与C/C++混合编程的两种模式

Python和C/C++的混合编程存在两种编程模式:扩展与嵌入,这里主要介绍前一种。C/C++编写Pyton扩展过程如下:

TensorFlow技术内幕(四):TF中的混合编程_第5张图片

图8:第一步、编写封装函数

TensorFlow技术内幕(四):TF中的混合编程_第6张图片

图9:第二步、编写模块初始话函数

TensorFlow技术内幕(四):TF中的混合编程_第7张图片

图10:完整的扩展例子

可以看到,手动完成扩展的编写还是挺低效的,最后还需要将编写完的封装代码和源码一起编译成动态链接库,限于篇幅,这里就不具体介绍了;我们来介绍一个自动化完成扩展编写的工具SWIG。

TensorFlow技术内幕(四):TF中的混合编程_第8张图片

图11:swig简介

SWIG是一个接口编译工具,连接C/C++代码与脚本语言Perl.Python,Ruby,Tcl的桥梁。SWIG为C/C++头文件自动生成包装代码,提供给脚本语言使用。

我们来看一个例子,假如我们有这样一个C代码文件example.c,其中包含了想要提供给其他语言如Perl,Python,java,C#代码使用的方法:

/* File : example.c */

 #include 
 double My_variable = 3.0;

 int fact(int n) {
     if (n <= 1) return 1;
     else return n*fact(n-1);
 }

 int my_mod(int x, int y) {
     return (x%y);
 }

 char *get_time()
 {
     time_t ltime;
     time(<ime);
     return ctime(<ime);
 }

那么首先我们需要创建一个”接口文件”,扩展名为 .i:

/* example.i */
%module example
%{
/* Put header files here or function declarations like below */
extern double  My_variable;
extern int fact(int n);
extern int my_mod(int x, int y);
extern char *get_time();
%}

extern double My_variable;
extern int fact(int n);
extern int my_mod(int x, int y);
extern char *get_time();

然后就可以构建其他语言的模块了,例如可以执行如下命令构建Python模块:

$ swig -python example.i
$ gcc -c example.c example_wrap.c -I/usr/local/include/python2.1
$ ld -shared example.o example_wrap.o -o _example.so

然后就可以调用生成的Python模块了:

>>> import example
>>> example.fact(5)
120
>>> example.my_mod(7,3)
1
>>> example.get_time()
'Sun Feb 11 23:01:07 1996'
>>>

可以通过执行下列命令生成Java模块:

$ swig -java example.i
$ gcc -c example.c example_wrap.c -I/c/jdk1.3.1/include -I/c/jdk1.3.1/include/win32
$ gcc -shared example.o example_wrap.o -mno-cygwin -Wl, --add-stdcall-alias -o example.dll

然后可以编写java代码,调用此模块:

/* file main.java */

public class main{
    public static void main(String argv[]){
        System.loadLobrary('example');
        System.out.println(example.getMy_variable());
        System.out.println(example.fact(5));
        System.out.println(example.get_time());
    }
}

最后执行调用:

$ javac main.java
$ java main
3.0
120
Mon Mar  4 18:20:31  2002
$

有了这些准备知识后,我们可以开始分析_py_wrap_cc_impl函数了:

# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):

  ##
  ## 下面的代码在构造SWIG的参数
  ##
  srcs = ctx.files.srcs
  if len(srcs) != 1:
    fail("Exactly one SWIG source file label must be specified.", "srcs")
  module_name = ctx.attr.module_name
  src = ctx.files.srcs[0]
  inputs = set([src])
  inputs += ctx.files.swig_includes
  for dep in ctx.attr.deps:
    inputs += dep.cc.transitive_headers
  inputs += ctx.files._swiglib
  inputs += ctx.files.toolchain_deps
  swig_include_dirs = set(_get_repository_roots(ctx, inputs))
  swig_include_dirs += sorted([f.dirname for f in ctx.files._swiglib])

  ##
  ## swig的命令行参数:-c++表示启动C++解析,-python表示输出python的
  ## wrapper代码,-module设置模块名称,-o表示输出文件名称,-outdir
  ## 表示输出目录路径,-l表示包含的SWIG的库文件名称,包括用户提供的.i文件
  ## 以及需要的SWIG库文件(也是一些.i文件),-I表示把参数路径添
  ## 加到SWIG的include查找路径。
  ## 
  args = [
      "-c++", "-python", "-module", module_name, "-o", ctx.outputs.cc_out.path,
      "-outdir", ctx.outputs.py_out.dirname
  ]
  args += ["-l" + f.path for f in ctx.files.swig_includes]
  args += ["-I" + i for i in swig_include_dirs]
  args += [src.path]
  outputs = [ctx.outputs.cc_out, ctx.outputs.py_out]

  ##
  ## 调用SWIG命令,生成swig files:ctx.action函数会启动一个
  ## 可执行文件或脚本executable, 启动参数arguments, inputs
  ## 表示表示输入文件,outputs表示输出文件。
  ##
  ctx.action(
      executable=ctx.executable._swig,
      arguments=args,
      inputs=list(inputs),
      outputs=outputs,
      mnemonic="PythonSwig",
      progress_message="SWIGing " + src.path)
  return struct(files=set(outputs))

C/C++的Python插件封装代码通过SWIG生成之后,就可以编译Python的插件了,这就是函数tf_py_wrap_cc后半部分所完成的工作:

def tf_py_wrap_cc(name,
                             srcs,
                             swig_includes=[],
                             deps=[],
                             copts=[],
                             **kwargs):
    ...
    ...
    ## 
    ## module_name + ".cc"是上面介绍的_py_wrap_cc规则的输出文件
    ## 也就是C/C++代码的Python封装代码,与dep中的C/C++代码一起,
    ## 通过cc_binary规则,生成动态链接库cc_library_name。
    ##
    native.cc_binary(
      name=cc_library_name,
      srcs=[module_name + ".cc"],
      copts=(copts + [
          "-Wno-self-assign", "-Wno-sign-compare", "-Wno-write-strings"
      ] + tf_extension_copts()),
      linkopts=tf_extension_linkopts() + extra_linkopts,
      linkstatic=1,
      linkshared=1,
      deps=deps + extra_deps)

  ##
  ## 定义一个生成规则,执行拷贝命令cp,将动态链接库cc_library_name拷贝一份,
  ## 名称为cc_library_pyd_name(python的pyd文件实际也就是windows平台下动态链接库,
  ## 只不过扩展名不一样而已)
  ##
  native.genrule(
      name="gen_" + cc_library_pyd_name,
      srcs=[":" + cc_library_name],
      outs=[cc_library_pyd_name],
      cmd="cp $< $@",)

  ##
  ## py_library规则定义一个python库目标,如果是windows平台下
  ## 则依赖.pyd文件cc_library_pyd_name,这会出发上面的拷贝动作,其他
  ## 平台下,则依赖.so文件cc_library_name
  native.py_library(
      name=name,
      srcs=[":" + name + ".py"],
      srcs_version="PY2AND3",
      data=select({
          clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name],
          "//conditions:default": [":" + cc_library_name],
      }))

最后,执行bazel build命令,在windows下,会生成下列文件:

pywrap_tensorflow_internal.cc
pywrap_tensorflow_internal.py
_pywrap_tensorflow_internal.pyd

而在非window平台下,则生成下列文件:

pywrap_tensorflow_internal.cc
pywrap_tensorflow_internal.py
_pywrap_tensorflow_internal.so

小结

总结一下,本章中,我们通过工具bazel query,找到了混合编程中链接两个世界的模块pywrap_tensorflow_internal,实际上它就是Python的一个扩展,python的代码通过这个扩展就可以调用底层的C/C++代码了。然后分析此工程的过程中,引入了SWIG工具,它使得C/C++代码很方便的导出到各种其他的脚本语言。

你可能感兴趣的:(TensorFlow技术内幕(四):TF中的混合编程)