Apollo5.5源码分析:对象工厂

一、概述

使用工厂模式,根据配置文件动态地生成对象。

二、简单工厂模式

简单工厂模式由一个工厂对象决定创建出哪一种产品类的实例。

代码位置Apollo/modules/perception/inference

用来做深度学习推理的框架CaffeNet\RTNet\PaddleNet都继承自Inference基类。

Apollo5.5源码分析:对象工厂_第1张图片

Inference *CreateInferenceByName(const std::string &name,
                                 const std::string &proto_file,
                                 const std::string &weight_file,
                                 const std::vector &outputs,
                                 const std::vector &inputs,
                                 const std::string &model_root) {
  if (name == "CaffeNet") {
    return new CaffeNet(proto_file, weight_file, outputs, inputs);
  } else if (name == "RTNet") {
    return new RTNet(proto_file, weight_file, outputs, inputs);
  } else if (name == "RTNetInt8") {
    return new RTNet(proto_file, weight_file, outputs, inputs, model_root);
  } else if (name == "PaddleNet") {
    return new PaddleNet(proto_file, weight_file, outputs, inputs);
  }
  return nullptr;
}
  • 工厂(Creator): CreateInferenceByName, 包含逻辑判断,根据外界给定信息,决定创建哪个具体产品的对象.
  • 抽象产品(Product): Inference, 负责描述所有实例所共有的公共接口。
  • 具体产品(Concrete Product): CaffeNet\RTNet\PaddleNet

Client使用示例:Apollo/modules/perception/camera/lib/obstacle/detector/yolo/yolo_obstacle_detector中:

std::shared_ptr inference_;  //yolo_obstacle_detector.h
inference_.reset(inference::CreateInferenceByName(model_type, proto_file,weight_file, output_names,input_names, model_root)); //yolo_obstacle_detector.cc

在使用示例中,/home/yly/pytorch_model/Apollo/modules/perception/camera/test/camera_lib_obstacle_detector_yolo_yolo_obstacle_detector_test.cc读取配置文件/Apollo/modules/perception/testdata/camera/lib/obstacle/detector/yolo/data/config.pt中model_type参数,发现是"RTNet",从而生成new RTNet(proto_file, weight_file, outputs, inputs);

优点:实现方便,适合负责创建的对象比较少的场景。在Inference中,推理框架一般不会频繁增加,因此使用简单工厂模式非常合适。

缺点:由于工厂类集中了所有实例的创建逻辑,它所能创建的类只能是事先考虑到的,如果需要添加新的类,则就需要改变工厂类了,破坏“开闭原则”。

三、工厂方法模式(基于宏定义)

工厂方法定义一个抽象工厂,由抽象工厂负责定义产品的生产接口,但不负责生产具体的产品,将生产任务交给不同的派生类工厂。这样不用通过指定具体类型来创建对象了(使用基类名称即可)。

Apollo5.5源码分析:对象工厂_第2张图片

优点:工厂方法模式实现“开闭原则”,保证可扩展性。

代码位置/home/yly/pytorch_model/Apollo/modules/perception/camera模块。

在https://blog.csdn.net/Cxiazaiyu/article/details/106256330中我们讲到Camera模块做了非常好的架构设计,interface文件夹中定义了功能接口,如BaseObstacleDetector(Apollo/modules/perception/camera/lib/interface/base_obstacle_detector.h) ,在app中(Apollo/modules/perception/camera/app/obstacle_camera_perception.cc)中使用基类接口创建对象。

Client使用示例:通过ObjectFactory的指针使用创建对象的方法,返回的是BaseObstacleDetector类的指针。

std::shared_ptr detector_ptr(BaseObstacleDetectorRegisterer::GetInstanceByName(plugin_param.name()));

实际创建的对象则是根据配置文件中载入的参数动态地创建对象,例如Apollo/modules/perception/camera/test/camera_app_obstacle_camera_perception_test.cc中调用配置文件/Apollo/modules/perception/testdata/camera/app/conf/perception/camera/obstacle/obstacle.pt,

detector_param {
  plugin_param{
    name: "YoloObstacleDetector"
    root_dir: "/apollo/modules/perception/production/data/perception/camera/models/yolo_obstacle_detector"
    config_file: "config.pt"
  }
  camera_name : "front_6mm"
}

从而,根据配置文件确定创建BaseObstacleDetector的派生类YoloObstacleDetector的实例。

 

  •  抽象工厂(Creator): Apollo/modules/perception/lib/registerer/registerer中定义
class ObjectFactory {
 public:
  ObjectFactory() {}
  virtual ~ObjectFactory() {}
  virtual Any NewInstance() { return Any(); }
  ObjectFactory(const ObjectFactory &) = delete;
  ObjectFactory &operator=(const ObjectFactory &) = delete;

 private:
};
  • 具体工厂(Concrete Creator):

在派生类定义中(/Apollo/modules/perception/camera/lib/obstacle/detector/yolo/yolo_obstacle_detector.cc),根据抽象工厂派生具体工厂:

class ObjectFactoryYoloObstacleDetector:public apollo::perception::lib::ObjectFactory { 
   public:                                                                    
    virtual ~ObjectFactoryYoloObstacleDetector() {}                                         
    virtual ::apollo::perception::lib::Any NewInstance() {                    
      return ::apollo::perception::lib::Any(new YoloObstacleDetector());                      
    }                                                                         
};  
                                                                               

在具体工厂中创建具体产品YoloObstacleDetector.

备注:以上为REGISTER_OBSTACLE_DETECTOR(YoloObstacleDetector)在Apollo/modules/perception/camera/lib/interface/base_obstacle_detector.h中的宏定义#define REGISTER_OBSTACLE_DETECTOR(name) \
  PERCEPTION_REGISTER_CLASS(BaseObstacleDetector, name) 以及Apollo/modules/perception/lib/registerer/registerer.h中对PERCEPTION_REGISTER_CLASS(clazz, name) 的宏定义的展开。

并且,向factory_map中注册新的派生类:

__attribute__((constructor)) void RegisterFactoryYoloObstacleDetector() {                 
    ::apollo::perception::lib::FactoryMap &map =                              
        ::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector'];  //key不存在的话则创建一个pair并调用默认构造函数              
    if (map.find('YoloObstacleDetector') == map.end()) 
        map['YoloObstacleDetector'] = new ObjectFactoryYoloObstacleDetector(); 
  } 

 说明:

a. __attribute__((constructor))表示该方法在main函数之前运行;

b. 嵌套地定义了双层map,即BaseClassMap类型的factory_map

typedef std::map FactoryMap;
typedef std::map BaseClassMap;
BaseClassMap &GlobalFactoryMap() {
  static BaseClassMap factory_map;
  return factory_map;
}

在本例中添加了元素['BaseObstacleDetector',['YoloObstacleDetector',ObjectFactory *]] 。

 

  • 抽象产品(Product):BaseObstacleDetector

在基类定义中注册BaseObstacleDetector (/Apollo/modules/perception/camera/lib/interface/base_obstacle_detector.h),即BaseObstacleDetectorRegisterer类中定义动态创建指向派生类的基类指针的方法。

class BaseObstacleDetectorRegisterer {                                      
    typedef ::apollo::perception::lib::Any Any;                       
    typedef ::apollo::perception::lib::FactoryMap FactoryMap;                                                                    
   public:                                                            
    static BaseObstacleDetector *GetInstanceByName(const ::std::string &name) { 
      FactoryMap &map =                                               
          ::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector']; 
      FactoryMap::iterator iter = map.find(name);                     
      if (iter == map.end()) {                                        
        for (auto c : map) {                                          
          AERROR << "Instance:" << c.first;                           
        }                                                             
        AERROR << "Get instance " << name << " failed.";              
        return nullptr;                                               
      }                                                               
      Any object = iter->second->NewInstance();                       
      return *(object.AnyCast());                       
    }                                                                 
    static std::vector GetAllInstances() {              
      std::vector instances;                            
      FactoryMap &map =                                               
          ::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector']; 
      instances.reserve(map.size());                                  
      for (auto item : map) {                                         
        Any object = item.second->NewInstance();                      
        instances.push_back(*(object.AnyCast()));       
      }                                                               
      return instances;                                               
    }                                                                 
    static const ::std::string GetUniqInstanceName() {                
      FactoryMap &map =                                               
          ::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector']; 
      CHECK_EQ(map.size(), 1) << map.size();                          
      return map.begin()->first;                                      
    }                                                                 
    static BaseObstacleDetector *GetUniqInstance() {                            
      FactoryMap &map =                                               
          ::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector']; 
      CHECK_EQ(map.size(), 1) << map.size();                          
      Any object = map.begin()->second->NewInstance();               
      return *(object.AnyCast());                       
    }                                                                 
    static bool IsValid(const ::std::string &name) {                  
      FactoryMap &map =                                               
          ::apollo::perception::lib::GlobalFactoryMap()['BaseObstacleDetector']; 
      return map.find(name) != map.end();                             
    }                                                                 
};

备注:以上为PERCEPTION_REGISTER_REGISTERER(BaseObstacleDetector)在Apollo/modules/perception/lib/registerer/registerer.h中对#define PERCEPTION_REGISTER_REGISTERER(base_class)的宏定义的展开。

  • 具体产品(Concrete Product): YoloObstacleDetector (开发者也可以基于BaseObstacleDetector在lib中实现自己的具体产品)。

 

为了创建多个具体工厂,使用宏定义保证代码简洁性和高度复用性。例如,在/Apollo/modules/perception/camera/lib/interface/base_obstacle_tracker.h中注册基类接口:

PERCEPTION_REGISTER_REGISTERER(BaseObstacleTracker);
#define REGISTER_OBSTACLE_TRACKER(name) \
  PERCEPTION_REGISTER_CLASS(BaseObstacleTracker, name)

在其派生类定义中中注册派生类:

REGISTER_OBSTACLE_TRACKER(OMTObstacleTracker);

则factory_map中添加了[BaseObstacleTracker,[OMTObstacleTracker,ObjectFactory *]].

该方法特点:在每个基类接口定义后注册基类,在每个派生类定义后对派生类进行注册,从而维护一个factory_map;在创建时使用基类+派生类的名称索引到相应的工厂创建对象。

四、工厂方法模式(基于模板类)

Apollo/modules/common/util/factory.h 中定义了另一种使用工厂方法的思路,即使用工厂模板类。

Apollo5.5源码分析:对象工厂_第3张图片

template >
class Factory {
 public:
  /**
   * @brief Registers the class given by the creator function, linking it to id.
   * Registration must happen prior to calling CreateObject.
   * @param id Identifier of the class being registered
   * @param creator Function returning a pointer to an instance of
   * the registered class
   * @return True if the key id is still available
   */
  bool Register(const IdentifierType &id, ProductCreator creator) {
    return producers_.insert(std::make_pair(id, creator)).second;
  }

  ...


  /**
   * @brief Creates and transfers membership of an object of type matching id.
   * Need to register id before CreateObject is called. May return nullptr
   * silently.
   * @param id The identifier of the class we which to instantiate
   * @param args the object construction arguments
   */
  template 
  std::unique_ptr CreateObjectOrNull(const IdentifierType &id,
                                                      Args &&... args) {
    auto id_iter = producers_.find(id);
    if (id_iter != producers_.end()) {
      return std::unique_ptr(
          (id_iter->second)(std::forward(args)...));
    }
    return nullptr;
  }

  /**
   * @brief Creates and transfers membership of an object of type matching id.
   * Need to register id before CreateObject is called.
   * @param id The identifier of the class we which to instantiate
   * @param args the object construction arguments
   */
  template 
  std::unique_ptr CreateObject(const IdentifierType &id,
                                                Args &&... args) {
    auto result = CreateObjectOrNull(id, std::forward(args)...);
    AERROR_IF(!result) << "Factory could not create Object of type : " << id;
    return result;
  }

 private:
  MapContainer producers_;
};

在Apollo/modules/planning/tasks/task_factory.h中针对task基类对工厂模板类进行了实例化:

class TaskFactory {
 public:
       ...
 private:
  static apollo::common::util::Factory>> task_factory_;
  static std::unordered_map>
      default_task_configs_;
};

特点:在Apollo/modules/planning/tasks/task_factory.cc定义中对task的派生类LaneChangeDecider、PATH_LANE_BORROW_DECIDER等集中进行注册与创建。

可参考:https://blog.csdn.net/davidhopper/article/details/79197075

 

你可能感兴趣的:(自动驾驶,设计模式)