TensorFlow c++ SessionFactory注册与No session factory registered错误

TensorFlow c++ SessionFactory注册与No session factory registered错误

背景

近期我们在服务器上使用TensorFlow来进行推理,作为云推理服务的基础。具体如何使用tensorflow c++库来进行推断可以参考之前的记录《从0开始使用tensorflow的c++库进行模型推断》。

在经过一些demo的验证之后,我们开始在项目中正式使用TensorFlow的c++库,简单描述我们的用法:tensorflow作为最底层,上面有一层对推理框架的封装静态库libnn.a,然后是main函数调用这个静态库和tf的动态库完成推理。在这个过程中执行NewSession()时遇到报错: No session factory registered for the given session options:{target: ""} Registered factories are {}

解决办法

参考网上的各种办法,均未解决。

目前我的解决办法是,在main函数的源文件中显示增加对TensorFlow相关头文件的引用,随便include以下头文件之一,然后再编译可执行文件,即可正确加载并运行,当然其他类似文件也可以,不过我只验证了这三个:

“tensorflow/core/public/session.h”
“tensorflow/client/client_session.h”
“tensorflow/core/framework/tensor.h”

问题分析

其实这个才是本文的主要部分。

从字面意思来看,这个报错是因为Session factory注册失败。
好!那我们就从什么是Session factory开始吧!

1. Session Factory

在tensorflow/core/common_runtime/session_factory.h中,我们找到了这个类:

class Session;
struct SessionOptions;
class SessionFactory {
public:
	virtual Status NewSession(const SessionOptions& options,
	                          Session** out_session) = 0;
	virtual bool AcceptsOptions(const SessionOptions& options) = 0;
	virtual Status Reset(const SessionOptions& options,
                       const std::vector<string>& containers) {
	    return errors::Unimplemented("Reset()");
	}
	static void Register(const string& runtime_type, SessionFactory* factory);
	static Status GetFactory(const SessionOptions& options,
                           SessionFactory** out_factory);
}                          

首先在这个头文件开始,声明了两个类Session 以及 SessionOptions,接下来再看成员函数名NewSession, AcceptsOptions, Reset, Register, GetFactory,我们大概可以猜到这个Factory是用来完成SessionFactory的注册,以及创建新的Session,过程中需要使用SessionOptions作为配置。实际上也的确如此。

如果我们在TensorFlow的源码中搜一下No session factory registered...报错,就可以发现这个错误的直接来源就是SessionFactory::GetFactory异常:

Status SessionFactory::GetFactory(const SessionOptions& options,
                                  SessionFactory** out_factory) {
	......
	} else {
    return errors::NotFound(
        "No session factory registered for the given session options: {",
        SessionOptionsToString(options), "} ",
        RegisteredFactoriesErrorMessageLocked());
  }

到此我们知道了报错的具体位置以及报错的表层原因

  1. main函数调用libnn.a中的init函数
  2. init函数调用Status NewSession(const SessionOptions& options, Session** out_session)函数
  3. 调用SessionFactory::GetFactory函数
  4. 报错

2. SessionFactory的注册

前面我们已经了解到报错的表层原因是因为执行GetFactory返回异常,接下来我们仔细分析一下为什么会出现这个异常,我们还是从SessionFactory::GetFactory入手:

Status SessionFactory::GetFactory(const SessionOptions& options,
                                  SessionFactory** out_factory) {
	...
	std::vector<std::pair<string, SessionFactory*>> candidate_factories;
	for (const auto& session_factory : *session_factories()) {
	 if (session_factory.second->AcceptsOptions(options)) {
	   VLOG(2) << "SessionFactory type " << session_factory.first
	           << " accepts target: " << options.target;
	   candidate_factories.push_back(session_factory);
	 } else {
	   VLOG(2) << "SessionFactory type " << session_factory.first
	           << " does not accept target: " << options.target;
	 }
	}
	
	if (candidate_factories.size() == 1) {
	    *out_factory = candidate_factories[0].second;
	    return Status::OK();
	} else if (candidate_factories.size() > 1) {
		 ...
	else {
	    return errors::NotFound(
	        "No session factory registered for the given session options: {",
	        SessionOptionsToString(options), "} ",
	        RegisteredFactoriesErrorMessageLocked());
	}
}

报错异常的直接原因,是因为candidate_factories<1,而这个候选工厂数量,是前面for循环得到的满足条件的factory。

满足什么条件呢?又是从什么集合里面去筛选呢?

我们一个个来分析,首先需要满足的条件是我们在外面定义的const SessionOptions& options,要创建Session,就必须传入一个SessionOptions对象,一般来说,如果不做一些细节调优,我们会在NewSession中传入一个SessionOptions的默认构造对象SessionOptions(),这个默认构造的对象只包含基本的环境变量,对所有factories来说都可以通过AcceptsOptions的判断的。

从什么集合里面去筛选候选Factory呢?在代码里面,就是for循环中的*session_factories(),具体来说就是:

typedef std::unordered_map<string, SessionFactory*> SessionFactories;
SessionFactories* session_factories() {
	static SessionFactories* factories = new SessionFactories;
	return factories;
}

当调用这个函数的时候,返回factories对象,即一个unordered_map,key是string,value是SessionFactory指针,即session_factory.second。

问题又来了,这个unordered_map,又是什么时候被赋值的呢?
在源文件中搜一下insert,就发现了这个函数:

void SessionFactory::Register(const string& runtime_type,
                              SessionFactory* factory) {
  mutex_lock l(*get_session_factory_lock());
  if (!session_factories()->insert({runtime_type, factory}).second) {
    LOG(ERROR) << "Two session factories are being registered "
               << "under" << runtime_type;
  }
}

每次调用Register,都会将一组key为runtime_type,value为factory的键值对放入unordered_map中,由于是map,会自动进行判重,也就是说每个runtime_type对应一个factory。

接下来的问题是,什么时候调用这个Register?
这个不太好找,涉及到TensorFlow的架构设计了,简单来说就是TensorFlow有两种runtime_type:direct和grpc,分别在tensorflow/core/common_runtime/direct_session.cctensorflow/core/common_runtime/grpc_session.cc中实现。

具体来说,是通过定义一个注册机DirectSessionRegistrar类,然后定义全局静态变量registrar,在registrar的初始化中,调用SessionFactory::Register方法,完成对应runtime_type的注册:

class DirectSessionRegistrar {
 public:
  DirectSessionRegistrar() {
    SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
  }
};
static DirectSessionRegistrar registrar;
class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

对Session的框架设计有更多的兴趣的,可以参考『深度长文』Tensorflow代码解析(五),这里就不再赘述了。

3. 动态库中全局静态变量的初始化

在上面我们已经分析清楚,这个报错的根本原因是TensorFlow动态库中的这两个源文件中的全局静态变量registrar没有初始化,导致后面在创建Session的时候GetFactory失败。

有一篇文章"Integrating TensorFlow libraries"也分析是这个原因。因为TF组件化的设计思路,一个很小的core+注册的方式来完成包括SessionFactory、Op_Kernel、Op等的加载,在编译的时候,如果这部分代码被编译器忽略了,就会导致部分代码注册失败,进而导致报错。

为什么不初始化呢?
动态库的加载分为显示加载和隐式加载,我们后面验证了显示加载,即显示通过dlopen来调用这些动态库,是可以正常运行的,说明这种情况下,所有的注册都完成了。问题出在隐式加载上。

为什么动态库隐式加载的情况下一些全局静态变量没有初始化呢?
这就是前面说过的解决办法了,隐式加载动态库,标准用法是在可执行文件的代码中显示include动态库export的函数的头文件,然后编译时加上对动态库的链接,这样系统会自动在运行前加载动态库。我们虽然在静态库中include了TF动态库的头文件,但是并未在可执行程序中引用,所以会导致TF动态库的全局静态变量初始化失败!

总结

  1. 使用隐式调用时,则调用方必须要加上动态库中的头文件,g++编译时还需要要用参数-I指明包含的头文件的位置
  2. 进一步通过SessionFactory学习了TensorFlow的模块化设计
  3. 深入Debug是一件很好玩的事情
  4. 千里之堤毁于蚁穴,一定要注意基础中的细节

参考

  1. linux下动态链接库(.so)的显式调用和隐式调用
  2. C语言中的 static变量(全局和局部)、static函数总结
  3. Integrating TensorFlow libraries
  4. 深度长文 Tensorflow代码解析(五)
  5. TensorFlow源码

你可能感兴趣的:(TensorFlow,TensorFlow,动态库,sessionfactory,全局静态变量,c++)