使用工厂模式,根据配置文件动态地生成对象。
简单工厂模式由一个工厂对象决定创建出哪一种产品类的实例。
代码位置Apollo/modules/perception/inference
用来做深度学习推理的框架CaffeNet\RTNet\PaddleNet都继承自Inference基类。
Inference *CreateInferenceByName(const std::string &name,
const std::string &proto_file,
const std::string &weight_file,
const std::vector &outputs,
const std::vector &inputs,
const std::string &model_root) {
if (name == "CaffeNet") {
return new CaffeNet(proto_file, weight_file, outputs, inputs);
} else if (name == "RTNet") {
return new RTNet(proto_file, weight_file, outputs, inputs);
} else if (name == "RTNetInt8") {
return new RTNet(proto_file, weight_file, outputs, inputs, model_root);
} else if (name == "PaddleNet") {
return new PaddleNet(proto_file, weight_file, outputs, inputs);
}
return nullptr;
}
Client使用示例:Apollo/modules/perception/camera/lib/obstacle/detector/yolo/yolo_obstacle_detector中:
std::shared_ptr inference_; //yolo_obstacle_detector.h
inference_.reset(inference::CreateInferenceByName(model_type, proto_file,weight_file, output_names,input_names, model_root)); //yolo_obstacle_detector.cc
在使用示例中,/home/yly/pytorch_model/Apollo/modules/perception/camera/test/camera_lib_obstacle_detector_yolo_yolo_obstacle_detector_test.cc读取配置文件/Apollo/modules/perception/testdata/camera/lib/obstacle/detector/yolo/data/config.pt中model_type参数,发现是"RTNet",从而生成new RTNet(proto_file, weight_file, outputs, inputs);
优点:实现方便,适合负责创建的对象比较少的场景。在Inference中,推理框架一般不会频繁增加,因此使用简单工厂模式非常合适。
缺点:由于工厂类集中了所有实例的创建逻辑,它所能创建的类只能是事先考虑到的,如果需要添加新的类,则就需要改变工厂类了,破坏“开闭原则”。
工厂方法定义一个抽象工厂,由抽象工厂负责定义产品的生产接口,但不负责生产具体的产品,将生产任务交给不同的派生类工厂。这样不用通过指定具体类型来创建对象了(使用基类名称即可)。
优点:工厂方法模式实现“开闭原则”,保证可扩展性。
代码位置/home/yly/pytorch_model/Apollo/modules/perception/camera模块。
在https://blog.csdn.net/Cxiazaiyu/article/details/106256330中我们讲到Camera模块做了非常好的架构设计,interface文件夹中定义了功能接口,如BaseObstacleDetector(Apollo/modules/perception/camera/lib/interface/base_obstacle_detector.h) ,在app中(Apollo/modules/perception/camera/app/obstacle_camera_perception.cc)中使用基类接口创建对象。
Client使用示例:通过ObjectFactory的指针使用创建对象的方法,返回的是BaseObstacleDetector类的指针。
std::shared_ptr detector_ptr(BaseObstacleDetectorRegisterer::GetInstanceByName(plugin_param.name()));
实际创建的对象则是根据配置文件中载入的参数动态地创建对象,例如Apollo/modules/perception/camera/test/camera_app_obstacle_camera_perception_test.cc中调用配置文件/Apollo/modules/perception/testdata/camera/app/conf/perception/camera/obstacle/obstacle.pt,
detector_param {
plugin_param{
name: "YoloObstacleDetector"
root_dir: "/apollo/modules/perception/production/data/perception/camera/models/yolo_obstacle_detector"
config_file: "config.pt"
}
camera_name : "front_6mm"
}
从而,根据配置文件确定创建BaseObstacleDetector的派生类YoloObstacleDetector的实例。
class ObjectFactory {
public:
ObjectFactory() {}
virtual ~ObjectFactory() {}
virtual Any NewInstance() { return Any(); }
ObjectFactory(const ObjectFactory &) = delete;
ObjectFactory &operator=(const ObjectFactory &) = delete;
private:
};
在派生类定义中(/Apollo/modules/perception/camera/lib/obstacle/detector/yolo/yolo_obstacle_detector.cc),根据抽象工厂派生具体工厂:
class ObjectFactoryYoloObstacleDetector:public apollo::perception::lib::ObjectFactory {
public:
virtual ~ObjectFactoryYoloObstacleDetector() {}
virtual ::apollo::perception::lib::Any NewInstance() {
return ::apollo::perception::lib::Any(new YoloObstacleDetector());
}
};
在具体工厂中创建具体产品YoloObstacleDetector.
备注:以上为REGISTER_OBSTACLE_DETECTOR(YoloObstacleDetector)在Apollo/modules/perception/camera/lib/interface/base_obstacle_detector.h中的宏定义#define REGISTER_OBSTACLE_DETECTOR(name) \
PERCEPTION_REGISTER_CLASS(BaseObstacleDetector, name) 以及Apollo/modules/perception/lib/registerer/registerer.h中对PERCEPTION_REGISTER_CLASS(clazz, name) 的宏定义的展开。
并且,向factory_map中注册新的派生类:
__attribute__((constructor)) void RegisterFactoryYoloObstacleDetector() {
::apollo::perception::lib::FactoryMap &map =
::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector']; //key不存在的话则创建一个pair并调用默认构造函数
if (map.find('YoloObstacleDetector') == map.end())
map['YoloObstacleDetector'] = new ObjectFactoryYoloObstacleDetector();
}
说明:
a. __attribute__((constructor))表示该方法在main函数之前运行;
b. 嵌套地定义了双层map,即BaseClassMap类型的factory_map
typedef std::map FactoryMap;
typedef std::map BaseClassMap;
BaseClassMap &GlobalFactoryMap() {
static BaseClassMap factory_map;
return factory_map;
}
在本例中添加了元素['BaseObstacleDetector',['YoloObstacleDetector',ObjectFactory *]] 。
在基类定义中注册BaseObstacleDetector (/Apollo/modules/perception/camera/lib/interface/base_obstacle_detector.h),即BaseObstacleDetectorRegisterer类中定义动态创建指向派生类的基类指针的方法。
class BaseObstacleDetectorRegisterer {
typedef ::apollo::perception::lib::Any Any;
typedef ::apollo::perception::lib::FactoryMap FactoryMap;
public:
static BaseObstacleDetector *GetInstanceByName(const ::std::string &name) {
FactoryMap &map =
::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector'];
FactoryMap::iterator iter = map.find(name);
if (iter == map.end()) {
for (auto c : map) {
AERROR << "Instance:" << c.first;
}
AERROR << "Get instance " << name << " failed.";
return nullptr;
}
Any object = iter->second->NewInstance();
return *(object.AnyCast());
}
static std::vector GetAllInstances() {
std::vector instances;
FactoryMap &map =
::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector'];
instances.reserve(map.size());
for (auto item : map) {
Any object = item.second->NewInstance();
instances.push_back(*(object.AnyCast()));
}
return instances;
}
static const ::std::string GetUniqInstanceName() {
FactoryMap &map =
::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector'];
CHECK_EQ(map.size(), 1) << map.size();
return map.begin()->first;
}
static BaseObstacleDetector *GetUniqInstance() {
FactoryMap &map =
::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector'];
CHECK_EQ(map.size(), 1) << map.size();
Any object = map.begin()->second->NewInstance();
return *(object.AnyCast());
}
static bool IsValid(const ::std::string &name) {
FactoryMap &map =
::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector'];
return map.find(name) != map.end();
}
};
备注:以上为PERCEPTION_REGISTER_REGISTERER(BaseObstacleDetector)在Apollo/modules/perception/lib/registerer/registerer.h中对#define PERCEPTION_REGISTER_REGISTERER(base_class)的宏定义的展开。
为了创建多个具体工厂,使用宏定义保证代码简洁性和高度复用性。例如,在/Apollo/modules/perception/camera/lib/interface/base_obstacle_tracker.h中注册基类接口:
PERCEPTION_REGISTER_REGISTERER(BaseObstacleTracker);
#define REGISTER_OBSTACLE_TRACKER(name) \
PERCEPTION_REGISTER_CLASS(BaseObstacleTracker, name)
在其派生类定义中中注册派生类:
REGISTER_OBSTACLE_TRACKER(OMTObstacleTracker);
则factory_map中添加了[BaseObstacleTracker,[OMTObstacleTracker,ObjectFactory *]].
该方法特点:在每个基类接口定义后注册基类,在每个派生类定义后对派生类进行注册,从而维护一个factory_map;在创建时使用基类+派生类的名称索引到相应的工厂创建对象。
Apollo/modules/common/util/factory.h 中定义了另一种使用工厂方法的思路,即使用工厂模板类。
template >
class Factory {
public:
/**
* @brief Registers the class given by the creator function, linking it to id.
* Registration must happen prior to calling CreateObject.
* @param id Identifier of the class being registered
* @param creator Function returning a pointer to an instance of
* the registered class
* @return True if the key id is still available
*/
bool Register(const IdentifierType &id, ProductCreator creator) {
return producers_.insert(std::make_pair(id, creator)).second;
}
...
/**
* @brief Creates and transfers membership of an object of type matching id.
* Need to register id before CreateObject is called. May return nullptr
* silently.
* @param id The identifier of the class we which to instantiate
* @param args the object construction arguments
*/
template
std::unique_ptr CreateObjectOrNull(const IdentifierType &id,
Args &&... args) {
auto id_iter = producers_.find(id);
if (id_iter != producers_.end()) {
return std::unique_ptr(
(id_iter->second)(std::forward(args)...));
}
return nullptr;
}
/**
* @brief Creates and transfers membership of an object of type matching id.
* Need to register id before CreateObject is called.
* @param id The identifier of the class we which to instantiate
* @param args the object construction arguments
*/
template
std::unique_ptr CreateObject(const IdentifierType &id,
Args &&... args) {
auto result = CreateObjectOrNull(id, std::forward(args)...);
AERROR_IF(!result) << "Factory could not create Object of type : " << id;
return result;
}
private:
MapContainer producers_;
};
在Apollo/modules/planning/tasks/task_factory.h中针对task基类对工厂模板类进行了实例化:
class TaskFactory {
public:
...
private:
static apollo::common::util::Factory>> task_factory_;
static std::unordered_map>
default_task_configs_;
};
特点:在Apollo/modules/planning/tasks/task_factory.cc定义中对task的派生类LaneChangeDecider、PATH_LANE_BORROW_DECIDER等集中进行注册与创建。
可参考:https://blog.csdn.net/davidhopper/article/details/79197075