近期我们在服务器上使用TensorFlow来进行推理,作为云推理服务的基础。具体如何使用tensorflow c++库来进行推断可以参考之前的记录《从0开始使用tensorflow的c++库进行模型推断》。
在经过一些demo的验证之后,我们开始在项目中正式使用TensorFlow的c++库,简单描述我们的用法:tensorflow作为最底层,上面有一层对推理框架的封装静态库libnn.a,然后是main函数调用这个静态库和tf的动态库完成推理。在这个过程中执行NewSession()时遇到报错: No session factory registered for the given session options:{target: ""} Registered factories are {}
。
参考网上的各种办法,均未解决。
目前我的解决办法是,在main函数的源文件中显示增加对TensorFlow相关头文件的引用,随便include以下头文件之一,然后再编译可执行文件,即可正确加载并运行,当然其他类似文件也可以,不过我只验证了这三个:
“tensorflow/core/public/session.h”
“tensorflow/client/client_session.h”
“tensorflow/core/framework/tensor.h”
其实这个才是本文的主要部分。
从字面意思来看,这个报错是因为Session factory注册失败。
好!那我们就从什么是Session factory开始吧!
在tensorflow/core/common_runtime/session_factory.h中,我们找到了这个类:
class Session;
struct SessionOptions;
class SessionFactory {
public:
virtual Status NewSession(const SessionOptions& options,
Session** out_session) = 0;
virtual bool AcceptsOptions(const SessionOptions& options) = 0;
virtual Status Reset(const SessionOptions& options,
const std::vector<string>& containers) {
return errors::Unimplemented("Reset()");
}
static void Register(const string& runtime_type, SessionFactory* factory);
static Status GetFactory(const SessionOptions& options,
SessionFactory** out_factory);
}
首先在这个头文件开始,声明了两个类Session 以及 SessionOptions,接下来再看成员函数名NewSession, AcceptsOptions, Reset, Register, GetFactory
,我们大概可以猜到这个Factory是用来完成SessionFactory的注册,以及创建新的Session,过程中需要使用SessionOptions作为配置。实际上也的确如此。
如果我们在TensorFlow的源码中搜一下No session factory registered...
报错,就可以发现这个错误的直接来源就是SessionFactory::GetFactory异常:
Status SessionFactory::GetFactory(const SessionOptions& options,
SessionFactory** out_factory) {
......
} else {
return errors::NotFound(
"No session factory registered for the given session options: {",
SessionOptionsToString(options), "} ",
RegisteredFactoriesErrorMessageLocked());
}
到此我们知道了报错的具体位置以及报错的表层原因
:
- main函数调用libnn.a中的init函数
- init函数调用Status NewSession(const SessionOptions& options, Session** out_session)函数
- 调用SessionFactory::GetFactory函数
- 报错
前面我们已经了解到报错的表层原因是因为执行GetFactory返回异常,接下来我们仔细分析一下为什么会出现这个异常,我们还是从SessionFactory::GetFactory入手:
Status SessionFactory::GetFactory(const SessionOptions& options,
SessionFactory** out_factory) {
...
std::vector<std::pair<string, SessionFactory*>> candidate_factories;
for (const auto& session_factory : *session_factories()) {
if (session_factory.second->AcceptsOptions(options)) {
VLOG(2) << "SessionFactory type " << session_factory.first
<< " accepts target: " << options.target;
candidate_factories.push_back(session_factory);
} else {
VLOG(2) << "SessionFactory type " << session_factory.first
<< " does not accept target: " << options.target;
}
}
if (candidate_factories.size() == 1) {
*out_factory = candidate_factories[0].second;
return Status::OK();
} else if (candidate_factories.size() > 1) {
...
else {
return errors::NotFound(
"No session factory registered for the given session options: {",
SessionOptionsToString(options), "} ",
RegisteredFactoriesErrorMessageLocked());
}
}
报错异常的直接原因,是因为candidate_factories<1,而这个候选工厂数量,是前面for循环得到的满足条件的factory。
满足什么条件呢?又是从什么集合里面去筛选呢?
我们一个个来分析,首先需要满足的条件是我们在外面定义的const SessionOptions& options,要创建Session,就必须传入一个SessionOptions对象,一般来说,如果不做一些细节调优,我们会在NewSession中传入一个SessionOptions的默认构造对象SessionOptions(),这个默认构造的对象只包含基本的环境变量,对所有factories来说都可以通过AcceptsOptions的判断的。
从什么集合里面去筛选候选Factory呢?在代码里面,就是for循环中的*session_factories(),具体来说就是:
typedef std::unordered_map<string, SessionFactory*> SessionFactories;
SessionFactories* session_factories() {
static SessionFactories* factories = new SessionFactories;
return factories;
}
当调用这个函数的时候,返回factories对象,即一个unordered_map,key是string,value是SessionFactory指针,即session_factory.second。
问题又来了,这个unordered_map,又是什么时候被赋值的呢?
在源文件中搜一下insert
,就发现了这个函数:
void SessionFactory::Register(const string& runtime_type,
SessionFactory* factory) {
mutex_lock l(*get_session_factory_lock());
if (!session_factories()->insert({runtime_type, factory}).second) {
LOG(ERROR) << "Two session factories are being registered "
<< "under" << runtime_type;
}
}
每次调用Register,都会将一组key为runtime_type,value为factory的键值对放入unordered_map中,由于是map,会自动进行判重,也就是说每个runtime_type对应一个factory。
接下来的问题是,什么时候调用这个Register?
这个不太好找,涉及到TensorFlow的架构设计了,简单来说就是TensorFlow有两种runtime_type:direct和grpc,分别在tensorflow/core/common_runtime/direct_session.cc
和 tensorflow/core/common_runtime/grpc_session.cc
中实现。
具体来说,是通过定义一个注册机DirectSessionRegistrar类,然后定义全局静态变量registrar,在registrar的初始化中,调用SessionFactory::Register方法,完成对应runtime_type的注册:
class DirectSessionRegistrar {
public:
DirectSessionRegistrar() {
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
}
};
static DirectSessionRegistrar registrar;
class GrpcSessionRegistrar {
public:
GrpcSessionRegistrar() {
SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
}
};
static GrpcSessionRegistrar registrar;
对Session的框架设计有更多的兴趣的,可以参考『深度长文』Tensorflow代码解析(五),这里就不再赘述了。
在上面我们已经分析清楚,这个报错的根本原因是TensorFlow动态库中的这两个源文件中的全局静态变量registrar没有初始化,导致后面在创建Session的时候GetFactory失败。
有一篇文章"Integrating TensorFlow libraries"也分析是这个原因。因为TF组件化的设计思路,一个很小的core+注册的方式来完成包括SessionFactory、Op_Kernel、Op等的加载,在编译的时候,如果这部分代码被编译器忽略了,就会导致部分代码注册失败,进而导致报错。
为什么不初始化呢?
动态库的加载分为显示加载和隐式加载,我们后面验证了显示加载,即显示通过dlopen来调用这些动态库,是可以正常运行的,说明这种情况下,所有的注册都完成了。问题出在隐式加载上。
为什么动态库隐式加载的情况下一些全局静态变量没有初始化呢?
这就是前面说过的解决办法了,隐式加载动态库,标准用法是在可执行文件的代码中显示include动态库export的函数的头文件,然后编译时加上对动态库的链接,这样系统会自动在运行前加载动态库。我们虽然在静态库中include了TF动态库的头文件,但是并未在可执行程序中引用,所以会导致TF动态库的全局静态变量初始化失败!