https://fate-serving.readthedocs.io/en/develop/?query=guest
What is Fate-Serving
fate-serving is the online part of FATE. Once federated modeling with FATE is complete, fate-serving can be used for online federated prediction, including single prediction, batch prediction and multi-host prediction.
Model initialization flow
After a model has been built in FATE, the fate-flow model-push script pushes it to serving-server. Once the push succeeds, serving-server registers the model's prediction interfaces in ZooKeeper, so that external systems can obtain the interface address through service discovery and call it.
Participant roles
When FATE's online prediction interface is called, the data consumer (Guest) and the data provider (Host) predict jointly. The Guest applies its business logic to the model and the feature data, and the sendToRemoteFeatureData field in the Guest's request parameters is forwarded to the Host. The Host obtains its feature data from its own business systems through a custom Adaptor (e.g. by calling a remote RPC interface or by reading from storage), hands those features to the algorithm module for computation, and the merged prediction result is finally returned to the Guest.
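As an illustration of the Host-side Adaptor idea, here is a minimal, hypothetical sketch; the interface name FeatureDataAdaptor, the method signature and the FeatureRpcClient helper are assumptions made for this sketch and do not reproduce the actual fate-serving adaptor API.

// Hypothetical Host-side feature adaptor: fetches features for the ids sent over by the Guest.
public class RpcFeatureAdaptor implements FeatureDataAdaptor {

    private final FeatureRpcClient rpcClient = new FeatureRpcClient("feature-service:9000");

    @Override
    public Map<String, Object> getFeatureData(Map<String, Object> featureIds) {
        // query the Host's own feature service (could equally be a local KV-store lookup)
        Map<String, Object> features = rpcClient.queryFeatures(featureIds);
        // the returned features are handed to the algorithm components for Host-side inference
        return features;
    }
}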
Fate-serving uses JDK 1.8 + SSM and exposes both HTTP interfaces and an RPC interface (gRPC), with roughly 110k lines of code; Secretflow-serving uses C++17 + brpc, exposes only an RPC interface, and has roughly 10k lines of code.
Secretflow-serving splits model execution into Executors and schedules them dynamically; FATE does not have this capability.
Capabilities FATE has that SecretFlow (隐语) lacks:
Service discovery has two dimensions: one is the queryModel-style interface inside serving, the other is the ZooKeeper watch/callback capability built on Curator. The second capability is not exposed to users and is only used internally.
model is the core of the inference part, so we look at it first. The FATE documentation explains the inference algorithms themselves (https://fate-serving.readthedocs.io/en/develop/algo/base/), so here we only look at the scheduling path.
The architecture of the model module is as follows:
Note that BaseComponent only implements LocalInferenceAware.
The model is loaded by ModelLoader in the server module; ModelLoader calls initModel to initialize the model.
Each model corresponds to one PipelineModelProcessor.
When serving-server receives a model-push request, it initializes a PipelineModelProcessor instance in memory. Much like SecretFlow splitting a model into executors, PipelineModelProcessor also splits the model into components. However, FATE's components are not the smallest scheduling unit, so FATE lacks SecretFlow's dynamic execution capability.
FATE's model is likewise defined with proto; dslParser parses the DSL and each component is then loaded dynamically. **Here we can see that, through reflection, FATE provides dynamic model loading.** For an online service, being able to register models dynamically matters a lot, because it enables hot updates of the service.
public int initModel(Context context, Map<String, byte[]> modelProtoMap) {
if (modelProtoMap != null) {
logger.info("start init pipeline,model components {}", modelProtoMap.keySet());
try {
Map<String, byte[]> newModelProtoMap = changeModelProto(modelProtoMap);
logger.info("after parse pipeline {}", newModelProtoMap.keySet());
Preconditions.checkArgument(newModelProtoMap.get(PIPLELINE_IN_MODEL) != null);
PipelineProto.Pipeline pipeLineProto = PipelineProto.Pipeline.parseFrom(newModelProtoMap.get(PIPLELINE_IN_MODEL));
String dsl = pipeLineProto.getInferenceDsl().toStringUtf8();
dslParser.parseDagFromDSL(dsl);
ArrayList<String> components = dslParser.getAllComponent();
HashMap<String, String> componentModuleMap = dslParser.getComponentModuleMap();
// call initModel on each component
for (int i = 0; i < components.size(); ++i) {
String componentName = components.get(i);
String className = componentModuleMap.get(componentName);
logger.info("try to get class:{}", className);
try {
// dynamically load the component class via reflection
Class modelClass = Class.forName(this.modelPackage + "." + className);
BaseComponent mlNode = (BaseComponent) modelClass.getConstructor().newInstance();
mlNode.setComponentName(componentName);
byte[] protoMeta = newModelProtoMap.get(componentName + ".Meta");
byte[] protoParam = newModelProtoMap.get(componentName + ".Param");
int returnCode = mlNode.initModel(protoMeta, protoParam);
if (returnCode == Integer.valueOf(StatusCode.SUCCESS)) {
componentMap.put(componentName, mlNode);
pipeLineNode.add(mlNode);
logger.info(" add class {} to pipeline task list", className);
} else {
throw new RuntimeException("init model error");
}
} catch (Exception ex) {
pipeLineNode.add(null);
logger.warn("Can not instance {} class", className);
}
}
} catch (Exception ex) {
logger.info("initModel error:{}", ex);
throw new RuntimeException("initModel error");
}
logger.info("Finish init Pipeline");
return Integer.valueOf(StatusCode.SUCCESS);
} else {
logger.error("model content is null ");
throw new RuntimeException("model content is null");
}
}
PipelineModelProcessor's guestInference is likewise called from the server module; besides the Context and the InferenceRequest, it also takes a futureMap parameter.
The futureMap here does not hold features; it holds the results of the remote (Host-side) inference.
guestInference first performs singleLocalPredict, calling each component's local inference (the LocalInferenceAware interface) in order.
It then merges those results with the remote inference results by calling each component's mergeRemoteInference method in order.
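Putting the two steps together, the guest-side flow looks roughly like the sketch below. This is a simplified, hypothetical summary, not the actual PipelineModelProcessor code; the simplified method signatures and the MergeInferenceAware name are assumptions.

// Simplified sketch of the guest-side inference flow described above (signatures are illustrative).
private Map<String, Object> guestPipelineSketch(Context context,
                                                Map<String, Object> input,
                                                Map<String, Object> remoteResult) {
    Map<String, Object> localResult = input;
    // step 1: singleLocalPredict - run every component's local inference in order
    for (BaseComponent component : pipeLineNode) {
        localResult = component.localInference(context, localResult);
    }
    // step 2: merge the local results with the Host-side results taken from futureMap
    for (BaseComponent component : pipeLineNode) {
        if (component instanceof MergeInferenceAware) {
            localResult = ((MergeInferenceAware) component)
                    .mergeRemoteInference(context, localResult, remoteResult);
        }
    }
    return localResult;
}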
The controller layer of fate-serving-server defines a set of HTTP endpoints and grpc.service defines the RPC endpoints; since the controllers merely build RPC calls, we will not cover them further.
Let us start with the ModelService part. ModelServiceProvider extends AbstractServingServiceProvider, an abstract class that in turn extends AbstractServiceAdaptor.
AbstractServiceAdaptor is the common abstract parent of every Service and ServiceProvider; let us see which interfaces and shared methods it provides.
Shared methods/fields:
Interfaces that subclasses must implement:
Next we look at the service method. It is passed a service Context; the Context is essentially a key-value store that records information about the call in progress.
@Override
public OutboundPackage<resp> service(Context context, InboundPackage<req> data) throws RuntimeException {
OutboundPackage<resp> outboundPackage = new OutboundPackage<resp>();
// increments requestInProcess by 1
context.preProcess();
List<Throwable> exceptions = Lists.newArrayList();
context.setReturnCode(StatusCode.SUCCESS);
// this flag is cleared (set to false) when the main method exits, i.e. during shutdown
if (!isOpen) {
return this.serviceFailInner(context, data, new ShowDownRejectException());
}
if(data.getBody()!=null) {
context.putData(Dict.INPUT_DATA, data.getBody());
}
try {
// count the request that is being handled
requestInHandle.addAndGet(1);
resp result = null;
context.setServiceName(this.serviceName);
try {
preChain.doPreProcess(context, data, outboundPackage);
// call the subclass's doService implementation
result = doService(context, data, outboundPackage);
if (logger.isDebugEnabled()) {
logger.debug("do service, router info: {}, service name: {}, result: {}", JsonUtil.object2Json(data.getRouterInfo()), serviceName, result);
}
} catch (Throwable e) {
exceptions.add(e);
logger.error("do service fail, cause by: {}", e.getMessage());
}
outboundPackage.setData(result);
postChain.doPostProcess(context, data, outboundPackage);
}
// ... (the catch/finally branches and the final return of service() are omitted here)
ModelService mainly relies on the ModelServiceProvider bean, which is the proxy for the model services; let us look at its code.
ModelServiceProvider uses ModelManager, which does the actual model management; ModelManager is covered in the next section.
ModelServiceProvider exposes several model services, such as model load, bind, unload and unbind.
The @FateService annotation configures AbstractServiceAdaptor's preChain and postChain:
@FateService(name = "modelService", preChain = {
"requestOverloadBreaker"
}, postChain = {
})
The chains configured by the @FateService annotation are wired up in the Register classes of the admin, server and proxy modules:
/**
 * onApplicationEvent is invoked once the Spring application has finished starting
 **/
@Override
public void onApplicationEvent(ApplicationReadyEvent applicationEvent) {
String[] beans = applicationContext.getBeanNamesForType(AbstractServiceAdaptor.class);
FlowCounterManager flowCounterManager = applicationContext.getBean(FlowCounterManager.class);
for (String beanName : beans) {
AbstractServiceAdaptor serviceAdaptor = applicationContext.getBean(beanName, AbstractServiceAdaptor.class);
serviceAdaptor.setFlowCounterManager(flowCounterManager);
// read the @FateService annotation from this bean
FateService proxyService = serviceAdaptor.getClass().getAnnotation(FateService.class);
Method[] methods = serviceAdaptor.getClass().getMethods();
for (Method method : methods) {
FateServiceMethod fateServiceMethod = method.getAnnotation(FateServiceMethod.class);
if (fateServiceMethod != null) {
String[] names = fateServiceMethod.name();
for (String name : names) {
serviceAdaptor.getMethodMap().put(name, method);
}
}
}
if (proxyService != null) {
serviceAdaptor.setServiceName(proxyService.name());
String[] postChain = proxyService.postChain();
String[] preChain = proxyService.preChain();
for (String post : postChain) {
Interceptor postInterceptor = applicationContext.getBean(post, Interceptor.class);
serviceAdaptor.addPostProcessor(postInterceptor);
}
for (String pre : preChain) {
Interceptor preInterceptor = applicationContext.getBean(pre, Interceptor.class);
serviceAdaptor.addPreProcessor(preInterceptor);
}
this.serviceAdaptorMap.put(proxyService.name(), serviceAdaptor);
}
}
logger.info("service register info {}", this.serviceAdaptorMap.keySet());
}
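For context, the methodMap populated above is what later allows a request to be dispatched by its action name. The snippet below is only a hypothetical sketch of that dispatch; the actionType variable and the surrounding wiring are assumptions for illustration.

// Sketch: dispatching a request by action name through the methodMap built above (illustrative only).
Method target = serviceAdaptor.getMethodMap().get(actionType);
if (target == null) {
    throw new UnsupportedOperationException("no @FateServiceMethod registered for " + actionType);
}
// reflectively invoke the @FateServiceMethod-annotated handler on the adaptor instance
Object result = target.invoke(serviceAdaptor, context, inboundPackage);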
ModelManager is a very important module: it does the actual work behind the model services. Let us walk through the methods mentioned above one by one.
The purpose of bind is to bind a service id to a model that has already been loaded.
It maintains a serviceId -> model key mapping (illustrated by a diagram in the official documentation).
The model pool is the namespaceMap, which maps a model name (tableName + namespace) to the loaded model and its ModelProcessor.
Note that every operation also updates a local cache; the local cache is used to restore the service after a restart.
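As a rough illustration of what such a file-backed cache can look like, here is a simplified sketch assuming plain Java serialization; the real ModelManager store/restore implementation may persist things differently.

// Simplified sketch of persisting an in-memory map so it can be restored after a restart (illustrative only).
private <K, V> void storeSketch(Map<K, V> map, File cacheFile) {
    try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(cacheFile))) {
        out.writeObject(new ConcurrentHashMap<>(map));   // snapshot the current mappings
    } catch (IOException e) {
        logger.error("store cache failed", e);
    }
}

@SuppressWarnings("unchecked")
private <K, V> Map<K, V> restoreSketch(File cacheFile) {
    try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(cacheFile))) {
        return (Map<K, V>) in.readObject();              // reload the mappings on startup
    } catch (Exception e) {
        return new ConcurrentHashMap<>();                // start empty if the cache is missing or unreadable
    }
}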
public synchronized ReturnResult bind(Context context, ModelServiceProto.PublishRequest req) {
if (logger.isDebugEnabled()) {
logger.debug("try to bind model, receive request : {}", req);
}
ReturnResult returnResult = new ReturnResult();
String serviceId = req.getServiceId();
Preconditions.checkArgument(StringUtils.isNotBlank(serviceId), "param service id is blank");
Preconditions.checkArgument(!StringUtils.containsAny(serviceId, URL_FILTER_CHARACTER), "Service id contains special characters, " + JsonUtil.object2Json(URL_FILTER_CHARACTER));
returnResult.setRetcode(StatusCode.SUCCESS);
Model model = this.buildModelForBind(context, req);
String modelKey = this.getNameSpaceKey(model.getTableName(), model.getNamespace());
Model loadedModel = this.namespaceMap.get(modelKey);
if (loadedModel == null) {
throw new ModelNullException("model " + modelKey + " is not exist ");
}
this.serviceIdNamespaceMap.put(serviceId, modelKey);
if (zookeeperRegistry != null) {
if (StringUtils.isNotEmpty(serviceId)) {
zookeeperRegistry.addDynamicEnvironment(serviceId);
}
zookeeperRegistry.register(FateServer.guestServiceSets, Lists.newArrayList(serviceId));
}
//update cache
this.store(serviceIdNamespaceMap, serviceIdFile);
return returnResult;
}
private Model buildModelForBind(Context context, ModelServiceProto.PublishRequest req) {
// build the model info from the request's modelMap;
// note that everything here is read from the request
Model model = new Model();
String role = req.getLocal().getRole();
model.setPartId(req.getLocal().getPartyId());
model.setRole(Dict.GUEST.equals(role) ? Dict.GUEST : Dict.HOST);
String serviceId = req.getServiceId();
model.getServiceIds().add(serviceId);
Map<String, ModelServiceProto.RoleModelInfo> modelMap = req.getModelMap();
ModelServiceProto.RoleModelInfo roleModelInfo = modelMap.get(model.getRole());
Map<String, ModelServiceProto.ModelInfo> modelInfoMap = roleModelInfo.getRoleModelInfoMap();
Map<String, ModelServiceProto.Party> roleMap = req.getRoleMap();
ModelServiceProto.Party selfParty = roleMap.get(model.getRole());
String selfPartyId = selfParty.getPartyIdList().get(0);
ModelServiceProto.ModelInfo selfModelInfo = modelInfoMap.get(selfPartyId);
String selfNamespace = selfModelInfo.getNamespace();
String selfTableName = selfModelInfo.getTableName();
model.setNamespace(selfNamespace);
model.setTableName(selfTableName);
return model;
}
When the data provider (host) loads a model, it records a mapping from the data consumer's (guest) table name + namespace to the (host) model, so that the consumer's model and the provider's model correspond one to one.
partnerModelMap is always empty on the guest side. namespaceMap exists on both host and guest and records the local model-pool mapping.
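This mapping is what lets the host locate its own model from the guest's model identifier carried in an inference request. A minimal, hypothetical sketch of that lookup (the guestTableName/guestNamespace variables are assumptions for illustration):

// Sketch: host-side model lookup keyed by the guest's model identifier (illustrative names).
String guestNamespaceKey = this.getNameSpaceKey(guestTableName, guestNamespace);
Model hostModel = this.partnerModelMap.get(guestNamespaceKey);
if (hostModel == null) {
    throw new ModelNullException("no host model bound to guest model " + guestNamespaceKey);
}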
public synchronized ReturnResult load(Context context, ModelServiceProto.PublishRequest req) {
if (logger.isDebugEnabled()) {
logger.debug("try to load model, receive request : {}", req);
}
ReturnResult returnResult = new ReturnResult();
returnResult.setRetcode(StatusCode.SUCCESS);
Model model = this.buildModelForLoad(context, req);
String namespaceKey = this.getNameSpaceKey(model.getTableName(), model.getNamespace());
ModelLoader.ModelLoaderParam modelLoaderParam = new ModelLoader.ModelLoaderParam();
String loadType = req.getLoadType();
if (StringUtils.isNotEmpty(loadType)) {
modelLoaderParam.setLoadModelType(ModelLoader.LoadModelType.valueOf(loadType));
} else {
modelLoaderParam.setLoadModelType(ModelLoader.LoadModelType.FATEFLOW);
}
modelLoaderParam.setTableName(model.getTableName());
modelLoaderParam.setNameSpace(model.getNamespace());
modelLoaderParam.setFilePath(req.getFilePath());
ModelLoader modelLoader = this.modelLoaderFactory.getModelLoader(context, modelLoaderParam.getLoadModelType());
Preconditions.checkArgument(modelLoader != null, "model loader not found");
ModelProcessor modelProcessor = modelLoader.loadModel(context, modelLoaderParam);
if (modelProcessor == null) {
throw new ModelProcessorInitException("model initialization error, please check if the model exists and the configuration of the FATEFLOW load model process is correct.");
}
model.setModelProcessor(modelProcessor);
modelProcessor.setModel(model);
// local model-pool mapping
this.namespaceMap.put(namespaceKey, model);
// when the data provider (host) loads a model, record the mapping from the data consumer's (guest) tableName + namespace to the (host) model
// so that the consumer's and the provider's models correspond one to one
if (Dict.HOST.equals(model.getRole())) {
model.getFederationModelMap().values().forEach(remoteModel -> {
String remoteNamespaceKey = this.getNameSpaceKey(remoteModel.getTableName(), remoteModel.getNamespace());
this.partnerModelMap.put(remoteNamespaceKey, model);
});
}
/**
* host model
*/
if (Dict.HOST.equals(model.getRole()) && zookeeperRegistry != null) {
String modelKey = ModelUtil.genModelKey(model.getTableName(), model.getNamespace());
zookeeperRegistry.addDynamicEnvironment(EncryptUtils.encrypt(modelKey, EncryptMethod.MD5));
zookeeperRegistry.register(FateServer.hostServiceSets);
}
// update cache
this.store(namespaceMap, namespaceFile);
return returnResult;
}
buildModelForLoad builds the Model object used for loading, including the mapping to the remote parties' models:
private Model buildModelForLoad(Context context, ModelServiceProto.PublishRequest req) {
Model model = new Model();
String role = req.getLocal().getRole();
model.setPartId(req.getLocal().getPartyId());
model.setRole(Dict.GUEST.equals(role) ? Dict.GUEST : Dict.HOST);
Map<String, ModelServiceProto.RoleModelInfo> modelMap = req.getModelMap();
ModelServiceProto.RoleModelInfo roleModelInfo = modelMap.get(model.getRole());
Map<String, ModelServiceProto.ModelInfo> modelInfoMap = roleModelInfo.getRoleModelInfoMap();
Map<String, ModelServiceProto.Party> roleMap = req.getRoleMap();
String remotePartyRole = model.getRole().equals(Dict.GUEST) ? Dict.HOST : Dict.GUEST;
ModelServiceProto.Party remoteParty = roleMap.get(remotePartyRole);
List<String> remotePartyIdList = remoteParty.getPartyIdList();
for (String remotePartyId : remotePartyIdList) {
ModelServiceProto.RoleModelInfo remoteRoleModelInfo = modelMap.get(remotePartyRole);
ModelServiceProto.ModelInfo remoteModelInfo = remoteRoleModelInfo.getRoleModelInfoMap().get(remotePartyId);
Model remoteModel = new Model();
remoteModel.setPartId(remotePartyId);
remoteModel.setNamespace(remoteModelInfo.getNamespace());
remoteModel.setTableName(remoteModelInfo.getTableName());
remoteModel.setRole(remotePartyRole);
model.getFederationModelMap().put(remotePartyId, remoteModel);
}
ModelServiceProto.Party selfParty = roleMap.get(model.getRole());
String selfPartyId = selfParty.getPartyIdList().get(0);
ModelServiceProto.ModelInfo selfModelInfo = modelInfoMap.get(model.getPartId());
Preconditions.checkArgument(selfModelInfo != null, "model info is invalid");
String selfNamespace = selfModelInfo.getNamespace();
String selfTableName = selfModelInfo.getTableName();
model.setNamespace(selfNamespace);
model.setTableName(selfTableName);
// for the FATEFLOW load type, resolve the model's resource address from fate-flow
if (ModelLoader.LoadModelType.FATEFLOW.name().equals(req.getLoadType())) {
try {
ModelLoader.ModelLoaderParam modelLoaderParam = new ModelLoader.ModelLoaderParam();
modelLoaderParam.setLoadModelType(ModelLoader.LoadModelType.FATEFLOW);
modelLoaderParam.setTableName(model.getTableName());
modelLoaderParam.setNameSpace(model.getNamespace());
modelLoaderParam.setFilePath(req.getFilePath());
ModelLoader modelLoader = this.modelLoaderFactory.getModelLoader(context, ModelLoader.LoadModelType.FATEFLOW);
model.setResourceAdress(getAdressForUrl(modelLoader.getResource(context, modelLoaderParam)));
} catch (Exception e) {
logger.error("getloadModelUrl error = {}", e);
}
}
return model;
}
During load, service registration only happens on the data-provider (host) side:
/**
* host model
*/
if (Dict.HOST.equals(model.getRole()) && zookeeperRegistry != null) {
String modelKey = ModelUtil.genModelKey(model.getTableName(), model.getNamespace());
zookeeperRegistry.addDynamicEnvironment(EncryptUtils.encrypt(modelKey, EncryptMethod.MD5));
zookeeperRegistry.register(FateServer.hostServiceSets);
}
So what is the DynamicEnvironment for, and when does FateServer.hostServiceSets get registered?
Let us first look at the initialization of FateServer.hostServiceSets. Reading the source, we find that the ServingServer bean implements InitializingBean; after initialization completes, the code below runs. It registers the initial set of services that Fate-Serving needs; as we will see later, all new services are derived from these initial ones.
@Override
public void afterPropertiesSet() throws Exception {
logger.info("try to star server ,meta info {}", MetaInfo.toMap());
Executor executor = new ThreadPoolExecutor(MetaInfo.PROPERTY_SERVING_CORE_POOL_SIZE, MetaInfo.PROPERTY_SERVING_MAX_POOL_SIZE, MetaInfo.PROPERTY_SERVING_POOL_ALIVE_TIME, TimeUnit.MILLISECONDS,
MetaInfo.PROPERTY_SERVING_POOL_QUEUE_SIZE == 0 ? new SynchronousQueue<Runnable>() :
(MetaInfo.PROPERTY_SERVING_POOL_QUEUE_SIZE < 0 ? new LinkedBlockingQueue<Runnable>()
: new LinkedBlockingQueue<Runnable>(MetaInfo.PROPERTY_SERVING_POOL_QUEUE_SIZE)), new NamedThreadFactory("ServingServer", true));
FateServerBuilder serverBuilder = (FateServerBuilder) ServerBuilder.forPort(MetaInfo.PROPERTY_SERVER_PORT);
serverBuilder.keepAliveTime(100, TimeUnit.MILLISECONDS);
serverBuilder.executor(executor);
serverBuilder.addService(ServerInterceptors.intercept(guestInferenceService, new ServiceExceptionHandler(), new ServiceOverloadProtectionHandle()), GuestInferenceService.class);
serverBuilder.addService(ServerInterceptors.intercept(modelService, new ServiceExceptionHandler(), new ServiceOverloadProtectionHandle()), ModelService.class);
serverBuilder.addService(ServerInterceptors.intercept(hostInferenceService, new ServiceExceptionHandler(), new ServiceOverloadProtectionHandle()), HostInferenceService.class);
serverBuilder.addService(ServerInterceptors.intercept(commonService, new ServiceExceptionHandler(), new ServiceOverloadProtectionHandle()), CommonService.class);
server = serverBuilder.build();
server.start();
boolean useRegister = MetaInfo.PROPERTY_USE_REGISTER;
if (useRegister) {
logger.info("serving-server is using register center");
zookeeperRegistry.subProject(Dict.PROPERTY_PROXY_ADDRESS);
zookeeperRegistry.subProject(Dict.PROPERTY_FLOW_ADDRESS);
zookeeperRegistry.register(FateServer.serviceSets);
} else {
logger.warn("serving-server not use register center");
}
modelManager.restore(new BaseContext());
logger.warn("serving-server start over");
}
Now for the second question, the role of DynamicEnvironment; let us look at the register code:
public synchronized void register(Set<RegisterService> sets) {
if (logger.isDebugEnabled()) {
logger.debug("prepare to register {}", sets);
}
String hostAddress = NetUtils.getLocalIp();
Preconditions.checkArgument(port != 0);
Preconditions.checkArgument(StringUtils.isNotEmpty(environment));
Set<URL> registered = this.getRegistered();
for (RegisterService service : sets) {
try {
URL url = generateUrl(hostAddress, service);
URL serviceUrl = url.setProject(project);
// for the inference services, useDynamicEnvironment is true
if (service.useDynamicEnvironment()) {
if (CollectionUtils.isNotEmpty(dynamicEnvironments)) {
dynamicEnvironments.forEach(environment -> {
URL newServiceUrl = service.protocol().equals(Dict.HTTP) ? url : serviceUrl.setEnvironment(environment);
// use cache service params
loadCacheParams(newServiceUrl);
// a new service is registered for every environment,
// so the number of registrations is (number of environments) * (size of the service set)
String serviceName = service.serviceName() + environment;
if (!registedString.contains(serviceName)) {
this.register(newServiceUrl);
this.registedString.add(serviceName);
} else {
logger.info("url {} is already registed, will not do anything ", newServiceUrl);
}
});
}
} else {
if (!registedString.contains(service.serviceName() + environment)) {
URL newServiceUrl = service.protocol().equals(Dict.HTTP) ? url : serviceUrl.setEnvironment(environment);
if (logger.isDebugEnabled()) {
logger.debug("try to register url {}", newServiceUrl);
}
// use cache service params
loadCacheParams(newServiceUrl);
this.register(newServiceUrl);
this.registedString.add(service.serviceName() + environment);
} else {
logger.info("url {} is already registed, will not do anything ", service.serviceName());
}
}
} catch (Exception e) {
e.printStackTrace();
logger.error("try to register service {} failed", service);
}
}
syncServiceCacheFile();
if (logger.isDebugEnabled()) {
logger.debug("registed urls {}", registered);
}
}
As we can see, each service in the set is registered once per dynamic environment, i.e. (number of environments) × (size of the set) registrations in total, which keeps the code simple; on the data-provider side each service is registered only once, which prevents duplicate registration.
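For example (hypothetical numbers): if dynamicEnvironments contains two entries, say the MD5 of one model key and one serviceId, and the service set contains two services, then 2 × 2 = 4 ZooKeeper services are registered, each named serviceName + environment.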
Bind follows the same idea as load above, except that bind is only called by the guest, so there is no need to distinguish guest from host:
if (zookeeperRegistry != null) {
if (StringUtils.isNotEmpty(serviceId)) {
zookeeperRegistry.addDynamicEnvironment(serviceId);
}
// register a new serviceId-scoped service for every service in guestServiceSets
zookeeperRegistry.register(FateServer.guestServiceSets, Lists.newArrayList(serviceId));
}
The logic of unload and unregister is similar, so we will not expand on them.
Next, let us look at the register module. There is a lot of code here, because routing, load balancing and other modules are implemented in it as well.
This section focuses on the registration logic, so we only need to read the common and zookeeper folders; the deployment-example diagram in the official documentation is a useful reference for what follows.
First, note that fate-serving does not implement ZooKeeper itself; the zk cluster has to be deployed by the user.
The main class used here is ZookeeperRegistry, so we start from it.
ZookeeperRegistry keeps a static ConcurrentMap (registeryMap) that maps each registry URL to a singleton ZookeeperRegistry instance.
Its initialization flow is as follows:
public static synchronized ZookeeperRegistry getRegistry(String url, String project, String environment, int port) {
if (url == null) {
return null;
}
URL registryUrl = URL.valueOf(url);
registryUrl = registryUrl.addParameter(Constants.ENVIRONMENT_KEY, environment);
registryUrl = registryUrl.addParameter(Constants.SERVER_PORT, port);
registryUrl = registryUrl.addParameter(Constants.PROJECT_KEY, project);
List<URL> backups = registryUrl.getBackupUrls();
if (registeryMap.get(registryUrl) == null) {
URL finalRegistryUrl = registryUrl;
registeryMap.computeIfAbsent(registryUrl, n -> {
CuratorZookeeperTransporter curatorZookeeperTransporter = new CuratorZookeeperTransporter();
ZookeeperRegistryFactory zookeeperRegistryFactory = new ZookeeperRegistryFactory();
zookeeperRegistryFactory.setZookeeperTransporter(curatorZookeeperTransporter);
ZookeeperRegistry zookeeperRegistry = (ZookeeperRegistry) zookeeperRegistryFactory.createRegistry(finalRegistryUrl);
return zookeeperRegistry;
});
}
return registeryMap.get(registryUrl);
}
Let us first look at CuratorZookeeperTransporter. It maintains a map from the ZooKeeper address to its ZookeeperClient (zookeeperClientMap), so that connections to the same address are reused.
Let us look at how a ZookeeperClient is initialized:
@Override
public ZookeeperClient connect(URL url) {
ZookeeperClient zookeeperClient;
// resolve all backup addresses from the url
List<String> addressList = getURLBackupAddress(url);
// The field define the zookeeper server , including protocol, host, port, username, password
// look up / update the url -> zookeeperClient mapping
if ((zookeeperClient = fetchAndUpdateZookeeperClientCache(addressList)) != null && zookeeperClient.isConnected()) {
logger.info("find valid zookeeper client from the cache for address: " + url);
return zookeeperClient;
}
// avoid creating too many connections, so add lock
synchronized (zookeeperClientMap) {
if ((zookeeperClient = fetchAndUpdateZookeeperClientCache(addressList)) != null && zookeeperClient.isConnected()) {
logger.info("find valid zookeeper client from the cache for address: " + url);
return zookeeperClient;
}
zookeeperClient = createZookeeperClient(toClientURL(url));
logger.info("No valid zookeeper client found from cache, therefore create a new client for url. " + url);
writeToClientMap(addressList, zookeeperClient);
// control flow then reaches the constructor below
}
return zookeeperClient;
}
public CuratorZookeeperClient(URL url) {
super(url);
try {
// read the connection timeout from the URL, default 5000 ms
int timeout = url.getParameter(TIMEOUT_KEY, 5000);
// build the Curator client with CuratorFrameworkFactory.Builder
CuratorFrameworkFactory.Builder builder = CuratorFrameworkFactory.builder()
.connectString(url.getBackupAddress()) // connection string, obtained via getBackupAddress()
.retryPolicy(new RetryNTimes(1, 1000)) // retry policy: retry once with a 1000 ms interval
.connectionTimeoutMs(timeout); // connection timeout
aclEnable = MetaInfo.PROPERTY_ACL_ENABLE;
if (aclEnable) {
aclUsername = MetaInfo.PROPERTY_ACL_USERNAME;
aclPassword = MetaInfo.PROPERTY_ACL_PASSWORD;
// if ACL is enabled, check whether the username and password are blank
if (StringUtils.isBlank(aclUsername) || StringUtils.isBlank(aclPassword)) {
aclEnable = false;
MetaInfo.PROPERTY_ACL_ENABLE = false;
} else {
// if username and password are present, add the authorization info and the ACL rule
builder.authorization(SCHEME, (aclUsername + ":" + aclPassword).getBytes());
Id allow = new Id(SCHEME, DigestAuthenticationProvider.generateDigest(aclUsername + ":" + aclPassword));
// add more
acls.add(new ACL(ZooDefs.Perms.ALL, allow));
}
}
// build the Curator client from the builder
client = builder.build();
// add a connection state listener to handle connection state changes
client.getConnectionStateListenable().addListener(new ConnectionStateListener() {
@Override
public void stateChanged(CuratorFramework client, ConnectionState state) {
// map the connection state change to the stateChanged callback
// only RECONNECTED is actually handled further up (by the registry's listener)
if (state == ConnectionState.LOST) {
CuratorZookeeperClient.this.stateChanged(StateListener.DISCONNECTED);
} else if (state == ConnectionState.CONNECTED) {
CuratorZookeeperClient.this.stateChanged(StateListener.CONNECTED);
} else if (state == ConnectionState.RECONNECTED) {
CuratorZookeeperClient.this.stateChanged(StateListener.RECONNECTED);
}
}
});
// start the Curator client
client.start();
// if ACL is enabled, set the ACL on the root node
if (aclEnable) {
client.setACL().withACL(acls).forPath("/");
}
} catch (Exception e) {
// wrap any failure in an IllegalStateException
throw new IllegalStateException(e.getMessage(), e);
}
}
Back to ZookeeperRegistry: after the client has been initialized, ZookeeperRegistry adds a state listener that is used to recover the registered services after a reconnection.
public ZookeeperRegistry(URL url, ZookeeperTransporter zookeeperTransporter) {
super(url);
String group = url.getParameter(ROOT_KEY, Dict.DEFAULT_FATE_ROOT);
if (!group.startsWith(PATH_SEPARATOR)) {
group = PATH_SEPARATOR + group;
}
this.environment = url.getParameter(ENVIRONMENT_KEY, "online");
project = url.getParameter(PROJECT_KEY);
port = url.getParameter(SERVER_PORT) != null ? new Integer(url.getParameter(SERVER_PORT)) : 0;
this.root = group;
zkClient = zookeeperTransporter.connect(url);
zkClient.addStateListener(state -> {
if (state == StateListener.RECONNECTED) {
logger.error("state listener reconnected");
try {
recover();
} catch (Exception e) {
logger.error(e.getMessage(), e);
}
}
});
}
// recover() eventually ends up calling:
public void addFailedRegisterComponentTask(URL url) {
if(url!=null) {
String instanceId = AbstractRegistry.INSTANCE_ID;
FailedRegisterComponentTask oldOne = this.failedRegisterComponent.get(instanceId);
if (oldOne != null) {
return;
}
// create a new retry task
FailedRegisterComponentTask newTask = new FailedRegisterComponentTask(url, this);
oldOne = failedRegisterComponent.putIfAbsent(instanceId, newTask);
if (oldOne == null) {
// never has a retry task. then start a new task for retry.
// set the retry delay; doRegisterComponent() is called once the timeout fires
retryTimer.newTimeout(newTask, retryPeriod, TimeUnit.MILLISECONDS);
}
}
}
Service registration eventually reaches the following client code:
// create an ephemeral node
@Override
public void createEphemeral(String path) {
try {
if (logger.isDebugEnabled()) {
logger.debug("createEphemeral {}", path);
}
if (aclEnable) {
// if ACL is enabled, create the ephemeral node with the configured ACLs
client.create().withMode(CreateMode.EPHEMERAL).withACL(acls).forPath(path);
} else {
// otherwise create the ephemeral node with default permissions
client.create().withMode(CreateMode.EPHEMERAL).forPath(path);
}
} catch (NodeExistsException e) {
// the node already exists: nothing to do
} catch (Exception e) {
throw new IllegalStateException(e.getMessage(), e);
}
}
// create a persistent node
@Override
protected void createPersistent(String path, String data) {
byte[] dataBytes = data.getBytes(CHARSET);
try {
if (logger.isDebugEnabled()) {
logger.debug("createPersistent {} data {}", path, data);
}
if (aclEnable) {
client.create().withACL(acls).forPath(path, dataBytes);
} else {
client.create().forPath(path, dataBytes);
}
} catch (NodeExistsException e) {
try {
if (aclEnable) {
Stat stat = client.checkExists().forPath(path);
client.setData().withVersion(stat.getAversion()).forPath(path, dataBytes);
} else {
client.setData().forPath(path, dataBytes);
}
} catch (Exception e1) {
throw new IllegalStateException(e.getMessage(), e1);
}
} catch (Exception e) {
throw new IllegalStateException(e.getMessage(), e);
}
}
subProject implements service discovery; it eventually ends up calling client.getChildren().usingWatcher(listener).forPath(path):
@Override
public void subProject(String project) {
if (logger.isDebugEnabled()) {
logger.debug("try to subProject: {}", project);
}
super.subProject(project);
failedSubProject.remove(project);
try {
doSubProject(project);
} catch (Exception e) {
addFailedSubscribedProjectTask(project);
}
}
@Override
public void doSubProject(String project) {
String path = root + Constants.PATH_SEPARATOR + project;
// watch root + Constants.PATH_SEPARATOR + project
List<String> environments = zkClient.addChildListener(path, (parent, childrens) -> {
if (StringUtils.isNotEmpty(parent)) {
logger.info("fire environments changes {}", childrens);
// subscribe to the newly appearing children (environments)
subEnvironments(path, project, childrens);
}
});
if (logger.isDebugEnabled()) {
logger.debug("environments {}", environments);
}
if (environments == null) {
if (logger.isDebugEnabled()) {
logger.debug("path {} is not exist in zk", path);
}
throw new RuntimeException("environment is null");
}
subEnvironments(path, project, environments);
}
private void subEnvironments(String path, String project, List<String> environments) {
if (environments != null) {
for (String environment : environments) {
String tempPath = path + Constants.PATH_SEPARATOR + environment;
// watch root + Constants.PATH_SEPARATOR + project + Constants.PATH_SEPARATOR + environment
List<String> services = zkClient.addChildListener(tempPath, (parent, childrens) -> {
if (StringUtils.isNotEmpty(parent)) {
if (logger.isDebugEnabled()) {
logger.debug("fire services changes {}", childrens);
}
subServices(project, environment, childrens);
}
});
subServices(project, environment, services);
}
}
}
If the parent node changes, the following method is called to subscribe to the services:
private void subServices(String project, String environment, List<String> services) {
if (services != null) {
for (String service : services) {
String subString = project + Constants.PATH_SEPARATOR + environment + Constants.PATH_SEPARATOR + service;
if (logger.isDebugEnabled()) {
logger.debug("subServices sub {}", subString);
}
subscribe(URL.valueOf(subString), urls -> {
if (logger.isDebugEnabled()) {
logger.debug("change services urls =" + urls);
}
});
}
}
}
Because the zk path structure used in fate-serving is
/FATE-SERVICES/{module}/{id}/{interface}/provider/{provider info}
and, as shown earlier, all of a user's new services are generated from the fixed modules, a newly registered service can also be discovered by the client. The registration of the original services happens in afterPropertiesSet(), which was covered above.
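For example (hypothetical values): after a host loads a model, an ephemeral node roughly of the form /FATE-SERVICES/serving/{md5(modelKey)}/inference/provider/{serving-server address} appears under the tree above, and any client subscribed to the serving module receives the child-change callback for it.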
We noticed that retryTimer appears in FailbackRegistry, the base class of ZookeeperRegistry; let us look at how it is implemented.
In ZookeeperRegistry and FailbackRegistry, when a task fails the code calls retryTimer.newTimeout(newTask, retryPeriod, TimeUnit.MILLISECONDS);
to start a scheduled retry task, which executes:
HashedWheelTimeout timeout = new HashedWheelTimeout(this, task, deadline);
timeouts.add(timeout);
This adds the task to the timeouts queue; the worker polls this queue and executes the task once its deadline is reached.
The HashedWheelTimer constructor runs the worker initialization logic:
workerThread = threadFactory.newThread(worker);
threadFactory is a naming thread-factory implementation that gives each thread a name.
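A naming thread factory is usually just a thin wrapper like the sketch below; this is a generic illustration, not the actual NamedThreadFactory shipped with fate-serving.

// Generic sketch of a naming ThreadFactory (illustrative, not the fate-serving source).
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

public class NamedThreadFactorySketch implements ThreadFactory {
    private final String prefix;
    private final boolean daemon;
    private final AtomicInteger counter = new AtomicInteger(1);

    public NamedThreadFactorySketch(String prefix, boolean daemon) {
        this.prefix = prefix;
        this.daemon = daemon;
    }

    @Override
    public Thread newThread(Runnable r) {
        Thread t = new Thread(r, prefix + "-thread-" + counter.getAndIncrement());
        t.setDaemon(daemon);   // worker threads are created as daemon threads in the server
        return t;
    }
}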
Let us continue with the worker's run method:
@Override
public void run() {
// Initialize the startTime.
startTime = System.nanoTime();
if (startTime == 0) {
// We use 0 as an indicator for the uninitialized value here, so make sure it's not 0 when initialized.
startTime = 1;
}
// Notify the other threads waiting for the initialization at start().
// synchronization between the thread that created the HashedWheelTimer and the worker thread:
// tasks may only be added after the worker has finished initializing
startTimeInitialized.countDown();
do {
final long deadline = waitForNextTick();
if (deadline > 0) {
// timeouts that share the same tick are put into the same bucket, the so-called HashedWheelBucket
int idx = (int) (tick & mask);
processCancelledTasks();
HashedWheelBucket bucket =
wheel[idx];
transferTimeoutsToBuckets();
// expire all timeouts in this bucket (their tasks are executed)
bucket.expireTimeouts(deadline);
tick++;
}
} while (WORKER_STATE_UPDATER.get(HashedWheelTimer.this) == WORKER_STATE_STARTED);
// Fill the unprocessedTimeouts so we can return them from stop() method.
for (HashedWheelBucket bucket : wheel) {
bucket.clearTimeouts(unprocessedTimeouts);
}
for (; ; ) {
// drain the remaining timeouts
HashedWheelTimeout timeout = timeouts.poll();
if (timeout == null) {
break;
}
if (!timeout.isCancelled()) {
unprocessedTimeouts.add(timeout);
}
}
processCancelledTasks();
}
The logic here is quite simple; it does not even use a min-heap, because the number of expiring retry tasks is actually small.
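From the caller's point of view the timer is used like Netty's HashedWheelTimer, on which this implementation is modeled; a minimal usage sketch, assuming the Netty-style newTimeout API:

// Minimal usage sketch of the hashed wheel timer (Netty-style API assumed).
Timer retryTimer = new HashedWheelTimer();
retryTimer.newTimeout(timeout -> {
    // runs once the delay has elapsed; a failed registration would be retried here
    System.out.println("retrying registration ...");
}, 5000, TimeUnit.MILLISECONDS);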
In the proxy module, the base class for routing is BaseServingRouter; it has two implementations, ConfigFileBasedServingRouter and ZkServingRouter.
In the register module, the base interface for routing is RouterService.
Here we look at the register side; load balancing is used mainly by the routing module, so we cover the two together.
We can see that AbstractRouterService uses LoadBalanceModel.random. RandomLoadBalance has a single selection algorithm, under which a node with a larger weight is more likely to be chosen.
public class RandomLoadBalance extends AbstractLoadBalancer {
public static final String NAME = "random";
@Override
protected List<URL> doSelect(List<URL> urls) {
// number of URLs in the list
int length = urls.size();
// flag indicating whether all URLs share the same weight
boolean sameWeight = true;
// array holding each URL's weight
int[] weights = new int[length];
// weight of the first URL, used as the baseline for comparison
int firstWeight = getWeight(urls.get(0));
weights[0] = firstWeight;
// running total of the weights, initialized with the first URL's weight
int totalWeight = firstWeight;
// iterate over the remaining URLs, accumulating the total weight and checking whether all weights are equal
for (int i = 1; i < length; i++) {
int weight = getWeight(urls.get(i));
weights[i] = weight;
totalWeight += weight;
// as soon as one URL's weight differs from the first, clear the flag
if (sameWeight && weight != firstWeight) {
sameWeight = false;
}
}
// if the total weight is positive and the weights are not all equal, do a weighted random pick
if (totalWeight > 0 && !sameWeight) {
// draw a random offset within the total weight
int offset = ThreadLocalRandom.current().nextInt(totalWeight);
// walk the URL list and pick the URL whose weight range contains the offset, so the probability is proportional to the weight
for (int i = 0; i < length; i++) {
offset -= weights[i];
if (offset < 0) {
// return the selected URL as a single-element list
return Lists.newArrayList(urls.get(i));
}
}
}
// fallback: if the total weight is non-positive or all weights are equal, pick uniformly at random
return Lists.newArrayList(urls.get(ThreadLocalRandom.current().nextInt(length)));
}
}
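As a quick worked example: with three URLs of weights 1, 2 and 3, totalWeight is 6; a random offset drawn from [0, 6) selects the first URL with probability 1/6, the second with 2/6 and the third with 3/6.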
The weight parameter is stored in a private volatile transient Map field inside URL; as we can see, there is no code that modifies it, so in the end the default weight is always used.