Replacing a Singleton Bean in Spring at Runtime

Preface

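Middleware such as Kafka and Redis can time out. When that happens we may need to change their connection addresses, which means replacing the corresponding beans at runtime.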

Approach

First we need to know how a bean gets into the Spring container and where it is stored. I won't go into detail here; roughly, it takes two steps:

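  • create the bean as needed
  • store it in DefaultSingletonBeanRegistry's three-level cache (singletonObjects, earlySingletonObjects, singletonFactories), above all org.springframework.beans.factory.support.DefaultSingletonBeanRegistry#singletonObjects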
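To make the three-level cache concrete, here is a tiny Spring-free model of it. This is a simplified sketch, not Spring's code: the class MiniSingletonRegistry and its methods are hypothetical, but the three map names, the lookup order in getSingleton (singletonObjects, then earlySingletonObjects, then singletonFactories) and the clean-up done by addSingleton follow DefaultSingletonBeanRegistry. Spring's "currently in creation" check and its locking details are left out.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Hypothetical, simplified model of the three caches in DefaultSingletonBeanRegistry.
// It skips Spring's "currently in creation" check and its finer-grained locking.
class MiniSingletonRegistry {
    // Level 1: fully initialized singletons
    final Map<String, Object> singletonObjects = new ConcurrentHashMap<>();
    // Level 2: early references handed out while a bean is still being created
    final Map<String, Object> earlySingletonObjects = new HashMap<>();
    // Level 3: factories that can produce an early reference on demand
    final Map<String, Supplier<Object>> singletonFactories = new HashMap<>();

    // Looks in level 1, then level 2, then level 3 (promoting a level-3 hit to level 2)
    synchronized Object getSingleton(String name) {
        Object bean = singletonObjects.get(name);
        if (bean == null) {
            bean = earlySingletonObjects.get(name);
        }
        if (bean == null) {
            Supplier<Object> factory = singletonFactories.remove(name);
            if (factory != null) {
                bean = factory.get();
                earlySingletonObjects.put(name, bean);
            }
        }
        return bean;
    }

    // Like addSingleton: store the finished bean in level 1 and clear levels 2 and 3
    synchronized void addSingleton(String name, Object bean) {
        singletonObjects.put(name, bean);
        singletonFactories.remove(name);
        earlySingletonObjects.remove(name);
    }
}
```

Since addSingleton simply overwrites the level-1 entry, calling it again with a new instance changes what any later lookup by that name returns. That is the lever the rest of this article uses.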

Preparing the code

A simple controller and the services it uses:

@RequestMapping()
@RestController
@Slf4j
public class DemoController {


    @Autowired
    InService inService;


    @GetMapping("print")
    public void print() {
        System.out.println(inService.getCourseService().toString());
        //System.out.println(courseInService.toString());
    }

}
@Component
@Getter
@Setter
public class CourseInService  {

    @Autowired
    OkService okService;

}
@Service
public class InService {
    @Autowired
    CourseInService courseService;


    public CourseInService getCourseService() {
        return courseService;
    }


}

@Service
public class OkService{
}

Our goal is to replace the courseService field inside InService.

How Spring creates a bean (rough source walkthrough)

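Tracing this in the source is easy. Since we know the bean ends up in singletonObjects, a look at the code shows that DefaultSingletonBeanRegistry#addSingleton is the only method that adds entries to singletonObjects, so just set a breakpoint there.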

[Screenshot: breakpoint in DefaultSingletonBeanRegistry#addSingleton]

Then start the project and you will see a call stack like this:
[Screenshot: call stack at the addSingleton breakpoint]

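The key method is AbstractBeanFactory#doGetBean. For a singleton it calls getSingleton(beanName, singletonFactory); the factory runs createBean, and once the new bean is returned, getSingleton registers it through addSingleton. In short:

  • create the bean first
  • then put it into singletonObjects

Replacing a bean follows the same two steps.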
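The two steps, and the catch they leave behind, can be shown with plain Java. In this sketch ReplaceDemo, ConnectionBean and ConsumerBean are hypothetical stand-ins for the singleton map, a client bean such as a Redis or Kafka connection, and a bean that had the client injected earlier. Overwriting the map entry changes what a lookup returns, but the consumer still holds the old object.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Hypothetical bean holding a connection address (think of a Redis or Kafka client)
class ConnectionBean {
    final String address;

    ConnectionBean(String address) {
        this.address = address;
    }
}

// Hypothetical bean that received ConnectionBean once, like an @Autowired field
class ConsumerBean {
    ConnectionBean connection;
}

class ReplaceDemo {
    // Stand-in for DefaultSingletonBeanRegistry#singletonObjects
    static final Map<String, Object> singletonObjects = new ConcurrentHashMap<>();

    // Step 1: create a fresh instance; step 2: overwrite the cached singleton
    static Object replace(String beanName, Supplier<Object> factory) {
        Object fresh = factory.get();
        singletonObjects.put(beanName, fresh);
        return fresh;
    }
}
```

This is why the code below does not stop after addSingleton: every bean that already holds the old instance has to be rewired as well.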

Code

// The existing bean
CourseInService bean = BaseSpringContext.getApplicationContext().getBean(CourseInService.class);

log.info("before update {}", bean.toString());

// Get the bean factory
DefaultListableBeanFactory beanFactory = (DefaultListableBeanFactory) BaseSpringContext.getApplicationContext().getAutowireCapableBeanFactory();

// Get the bean name
String beanName = beanFactory.getBeanNamesForType(CourseInService.class)[0];

// Merged bean definition (mbd), via the protected getMergedLocalBeanDefinition
Method getMergedLocalBeanDefinitionM = MethodUtils.getMatchingMethod(AbstractBeanFactory.class, "getMergedLocalBeanDefinition", String.class);
getMergedLocalBeanDefinitionM.setAccessible(true);
RootBeanDefinition definition = (RootBeanDefinition) getMergedLocalBeanDefinitionM.invoke(beanFactory, beanName);

// Reflectively call the protected createBean(beanName, mbd, args) to build a fresh instance
Method createBeanM = MethodUtils.getMatchingMethod(AbstractAutowireCapableBeanFactory.class, "createBean", String.class, RootBeanDefinition.class, Object[].class);
createBeanM.setAccessible(true);
CourseInService courseInService = (CourseInService) createBeanM.invoke(beanFactory, beanName, definition, null);

// Reflectively overwrite the cached entry in singletonObjects
Method addSingletonM = MethodUtils.getMatchingMethod(DefaultSingletonBeanRegistry.class, "addSingleton", String.class, Object.class);
addSingletonM.setAccessible(true);
addSingletonM.invoke(beanFactory, beanName, courseInService);

// Beans that already hold the old instance (e.g. InService) must be updated right away
Field singletonObjectsF = FieldUtils.getDeclaredField(DefaultSingletonBeanRegistry.class, "singletonObjects", true);
@SuppressWarnings("unchecked")
Map<String, Object> singletonObjects = (Map<String, Object>) singletonObjectsF.get(beanFactory);
for (Object value : singletonObjects.values()) {
    // Skip beans that do not belong to this project
    if (!value.getClass().getName().startsWith("com.hgf")) {
        continue;
    }

    Field[] declaredFields = value.getClass().getDeclaredFields();
    for (Field declaredField : declaredFields) {
        if (!declaredField.getType().equals(CourseInService.class)) {
            continue;
        }

        log.info("{} dependency updated", value.getClass().getName());
        FieldUtils.writeField(declaredField, value, courseInService, true);
    }
}

// Look the bean up again: the local variable `bean` still points at the old instance
bean = BaseSpringContext.getApplicationContext().getBean(CourseInService.class);
log.info("after update {}", bean.toString());

Then call this code from an endpoint in the controller. After invoking it, the log shows:

2022-06-10 15:32:05.476  INFO 26322 --- [io-13399-exec-1] c.h.z.controller.DemoController          : before update com.hgf.zulu_spring_demo.service.CourseInService@5c1388f3
2022-06-10 15:32:07.516  INFO 26322 --- [io-13399-exec-1] c.h.z.controller.DemoController          : com.hgf.zulu_spring_demo.service.InService dependency updated
2022-06-10 15:32:07.516  INFO 26322 --- [io-13399-exec-1] c.h.z.controller.DemoController          : after update com.hgf.zulu_spring_demo.service.CourseInService@709060f7
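The dependency-update loop boils down to one reusable step: scan each candidate object for fields whose declared type is the replaced bean's type and overwrite them reflectively. The JDK-only sketch below shows that step; FieldRewirer, Client, BaseService and OrderService are hypothetical. It differs from the article's loop in two deliberate ways: it walks up the superclass chain, because getDeclaredFields() only returns fields declared by the class itself, and it skips static fields.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Collection;

class FieldRewirer {
    // Writes `replacement` into every non-static field whose declared type is exactly
    // `type`, in each candidate object. Returns how many fields were written.
    static <T> int rewire(Collection<?> candidates, Class<T> type, T replacement) {
        int written = 0;
        for (Object candidate : candidates) {
            // getDeclaredFields() omits inherited fields, so walk up the hierarchy
            for (Class<?> c = candidate.getClass(); c != null && c != Object.class; c = c.getSuperclass()) {
                for (Field field : c.getDeclaredFields()) {
                    if (Modifier.isStatic(field.getModifiers()) || !field.getType().equals(type)) {
                        continue;
                    }
                    try {
                        field.setAccessible(true);
                        field.set(candidate, replacement);
                        written++;
                    } catch (IllegalAccessException e) {
                        throw new IllegalStateException("cannot write " + field, e);
                    }
                }
            }
        }
        return written;
    }
}

// Hypothetical beans used to exercise the helper
class Client {
    final String address;

    Client(String address) {
        this.address = address;
    }
}

class BaseService {
    // Inherited by OrderService; getDeclaredFields() on OrderService alone misses it
    private Client client;

    Client client() {
        return client;
    }
}

class OrderService extends BaseService {
    Client cachedClient;
}
```

Inside the article's Spring code the same hierarchy walk could come from commons-lang3's FieldUtils.getAllFieldsList(Class) instead of value.getClass().getDeclaredFields().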

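That almost does it, but there is a small bug: if a bean that depends on CourseInService is advised by an aspect, its dependency is not updated. For example, let the controller depend on CourseInService directly and advise every public controller method with the aspect below: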

[Screenshot: DemoController with an injected CourseInService field]

@Component
@Aspect
@Order(0)
@Slf4j
public class ValidationAspect {

    @Pointcut("execution(public * com.hgf.zulu_spring_demo.controller..*.*(..))")
    public void validationPointCut() {}

    @Around("validationPointCut()")
    public Object valid(ProceedingJoinPoint joinPoint) throws Throwable {
        return joinPoint.proceed();
    }


}

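Calling the endpoint now shows that the controller's dependency was not updated, because the controller bean is a proxy class. The loop calls getDeclaredFields() on the proxy's class, which does not include the fields declared in DemoController, so the CourseInService field is never found and the real controller instance behind the proxy keeps the old reference.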

[Screenshot: the DemoController bean in singletonObjects is a proxy class]
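The proxy problem can be reproduced without Spring. The sketch below hand-writes a subclass proxy in the spirit of what CGLIB generates for DemoController; OrderController, OrderControllerProxy, PaymentClient and ProxyFieldCheck are hypothetical names, and a real CGLIB proxy reaches its target through generated callbacks rather than a plain `target` field. The proxy's class declares no PaymentClient field, and even the field it inherits is a separate, unused copy, because every call is delegated to the target.

```java
import java.lang.reflect.Field;

// Hypothetical dependency and controller
class PaymentClient {
    final String address;

    PaymentClient(String address) {
        this.address = address;
    }
}

class OrderController {
    PaymentClient paymentClient;

    String currentAddress() {
        return paymentClient.address;
    }
}

// Simplified stand-in for a CGLIB subclass proxy: it extends the controller but
// delegates every call to the real target instance.
class OrderControllerProxy extends OrderController {
    private final OrderController target;

    OrderControllerProxy(OrderController target) {
        this.target = target;
    }

    @Override
    String currentAddress() {
        // an @Around advice would wrap this call
        return target.currentAddress();
    }
}

class ProxyFieldCheck {
    // The same test the article's loop applies: does this exact class
    // declare a field of the given type?
    static boolean declaresFieldOfType(Class<?> cls, Class<?> type) {
        for (Field f : cls.getDeclaredFields()) {
            if (f.getType().equals(type)) {
                return true;
            }
        }
        return false;
    }
}
```

Writing to the proxy's inherited paymentClient field therefore changes nothing observable; only the target's field matters.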

Improvement

There are roughly two ways to fix this:

  • find the real controller bean behind the proxy and modify that
  • modify the controller bean directly
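For comparison, the first option amounts to "if the bean is a proxy, follow it to its target, then rewire the target". The sketch below models that with a hypothetical TargetExposingProxy interface; ProxyUnwrapper, MailClient, NoticeController and NoticeControllerProxy are likewise made up for illustration. In Spring itself, AopProxyUtils.getSingletonTarget(Object) returns the target behind a proxy that has a singleton target source (and null otherwise), which is the piece this interface stands in for.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;

// Hypothetical contract for proxies that expose their target
interface TargetExposingProxy {
    Object proxiedTarget();
}

class ProxyUnwrapper {
    // Follow proxy -> target links until a non-proxy object is reached
    static Object unwrap(Object bean) {
        Object current = bean;
        while (current instanceof TargetExposingProxy) {
            current = ((TargetExposingProxy) current).proxiedTarget();
        }
        return current;
    }

    // Rewire fields of exactly `type` on the unwrapped target, not on the proxy
    static <T> int rewire(Object bean, Class<T> type, T replacement) {
        Object target = unwrap(bean);
        int written = 0;
        for (Field field : target.getClass().getDeclaredFields()) {
            if (Modifier.isStatic(field.getModifiers()) || !field.getType().equals(type)) {
                continue;
            }
            try {
                field.setAccessible(true);
                field.set(target, replacement);
                written++;
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("cannot write " + field, e);
            }
        }
        return written;
    }
}

// Hypothetical beans
class MailClient {
    final String host;

    MailClient(String host) {
        this.host = host;
    }
}

class NoticeController {
    MailClient mailClient;

    String host() {
        return mailClient.host;
    }
}

class NoticeControllerProxy extends NoticeController implements TargetExposingProxy {
    private final NoticeController target;

    NoticeControllerProxy(NoticeController target) {
        this.target = target;
    }

    @Override
    String host() {
        return target.host();
    }

    @Override
    public Object proxiedTarget() {
        return target;
    }
}
```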

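The first option would mean digging deep into the Spring source, so I ruled it out. The second one is a cache: record each project bean in our own map while Spring initializes it, so we keep a direct reference to the real object.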


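@Component
public class SpringCacheHelper implements BeanPostProcessor {
    private static final Map<String, Object> beanCacheMap = new ConcurrentHashMap<>(256);

    public static void addCache(String beanName, Object singletonObject) {
        // ConcurrentHashMap is already thread-safe; no extra locking needed
        beanCacheMap.put(beanName, singletonObject);
    }

    public static Map<String, Object> getBeanCacheMap() {
        return beanCacheMap;
    }

    // Cache project beans before initialization, while the bean is still the raw instance.
    // By the time this (non-Ordered) processor's postProcessAfterInitialization runs, the
    // AOP auto-proxy creator may already have replaced an advised bean with its proxy.
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        if (bean.getClass().getName().startsWith("com.hgf")) {
            beanCacheMap.put(beanName, bean);
        }
        return bean;
    }
}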
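Which object a BeanPostProcessor receives depends on when it runs. In a typical Spring Boot application the auto-proxy creator that applies @Aspect advice is registered as an Ordered post-processor with the highest precedence and creates proxies in postProcessAfterInitialization, while a post-processor that implements neither Ordered nor PriorityOrdered, such as SpringCacheHelper, is applied after it. The JDK-only model below (Hook, WrappingHook, RecordingHook and HookChain are hypothetical) runs a chain of before/after hooks in the same shape as initializeBean: a cache filled in the after-hook ends up holding the wrapper, while the before-hook still sees the raw bean.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, minimal post-processor contract with before/after hooks
interface Hook {
    default Object before(Object bean, String name) { return bean; }
    default Object after(Object bean, String name) { return bean; }
}

// Wraps beans after initialization, like an auto-proxy creator
class WrappingHook implements Hook {
    static final class Wrapper {
        final Object target;

        Wrapper(Object target) {
            this.target = target;
        }
    }

    @Override
    public Object after(Object bean, String name) {
        return new Wrapper(bean);
    }
}

// Records what it sees in each phase, like a bean cache
class RecordingHook implements Hook {
    final Map<String, Object> seenBefore = new HashMap<>();
    final Map<String, Object> seenAfter = new HashMap<>();

    @Override
    public Object before(Object bean, String name) {
        seenBefore.put(name, bean);
        return bean;
    }

    @Override
    public Object after(Object bean, String name) {
        seenAfter.put(name, bean);
        return bean;
    }
}

class HookChain {
    final List<Hook> hooks = new ArrayList<>();

    // All before-hooks, then (init methods would run here), then all after-hooks
    Object initialize(Object bean, String name) {
        Object current = bean;
        for (Hook h : hooks) {
            current = h.before(current, name);
        }
        for (Hook h : hooks) {
            current = h.after(current, name);
        }
        return current;
    }
}
```

That is why caching in postProcessBeforeInitialization is the safer choice when the goal is to reach the real object's fields.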

Then change the dependency-update loop to iterate over this cache:

Collection<Object> values = SpringCacheHelper.getBeanCacheMap().values();
for (Object value : values) {
    Field[] declaredFields = value.getClass().getDeclaredFields();
    for (Field declaredField : declaredFields) {
        if (!declaredField.getType().equals(CourseInService.class)) {
            continue;
        }

        log.info("{} dependency updated", value.getClass().getName());
        FieldUtils.writeField(declaredField, value, courseInService, true);
    }
}

Done.

Core code

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Map;

import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory;
import org.springframework.beans.factory.support.AbstractBeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.DefaultSingletonBeanRegistry;
import org.springframework.beans.factory.support.RootBeanDefinition;

// BaseSpringContext is the project's own static ApplicationContext holder (not shown)
public void registry() throws Exception {

    CourseInService bean = BaseSpringContext.getApplicationContext().getBean(CourseInService.class);

    log.info("before update {}", bean.toString());

    // Get the bean factory
    DefaultListableBeanFactory beanFactory = (DefaultListableBeanFactory) BaseSpringContext.getApplicationContext().getAutowireCapableBeanFactory();

    // Get the bean name
    String beanName = beanFactory.getBeanNamesForType(CourseInService.class)[0];

    // Merged bean definition (mbd), via the protected getMergedLocalBeanDefinition
    Method getMergedLocalBeanDefinitionM = MethodUtils.getMatchingMethod(AbstractBeanFactory.class, "getMergedLocalBeanDefinition", String.class);
    getMergedLocalBeanDefinitionM.setAccessible(true);
    RootBeanDefinition definition = (RootBeanDefinition) getMergedLocalBeanDefinitionM.invoke(beanFactory, beanName);

    // createBean(beanName, mbd, args): builds a fresh instance, including injection and post-processors
    Method createBeanM = MethodUtils.getMatchingMethod(AbstractAutowireCapableBeanFactory.class, "createBean", String.class, RootBeanDefinition.class, Object[].class);
    createBeanM.setAccessible(true);
    CourseInService courseInService = (CourseInService) createBeanM.invoke(beanFactory, beanName, definition, null);

    // Overwrite the cached entry in singletonObjects
    Method addSingletonM = MethodUtils.getMatchingMethod(DefaultSingletonBeanRegistry.class, "addSingleton", String.class, Object.class);
    addSingletonM.setAccessible(true);
    addSingletonM.invoke(beanFactory, beanName, courseInService);

    // Update project beans in singletonObjects that still hold the old instance
    Field singletonObjectsF = FieldUtils.getDeclaredField(DefaultSingletonBeanRegistry.class, "singletonObjects", true);
    @SuppressWarnings("unchecked")
    Map<String, Object> singletonObjects = (Map<String, Object>) singletonObjectsF.get(beanFactory);
    for (Object value : singletonObjects.values()) {
        // Skip beans that do not belong to this project
        if (!value.getClass().getName().startsWith("com.hgf")) {
            continue;
        }

        Field[] declaredFields = value.getClass().getDeclaredFields();
        for (Field declaredField : declaredFields) {
            if (!declaredField.getType().equals(CourseInService.class)) {
                continue;
            }

            log.info("{} dependency updated", value.getClass().getName());
            FieldUtils.writeField(declaredField, value, courseInService, true);
        }
    }

    // Look the bean up again so the final log shows the new instance
    bean = BaseSpringContext.getApplicationContext().getBean(CourseInService.class);

    // Also update the beans recorded by SpringCacheHelper, which is meant to hold
    // the real objects rather than their AOP proxies
    Collection<Object> values = SpringCacheHelper.getBeanCacheMap().values();
    for (Object value : values) {
        Field[] declaredFields = value.getClass().getDeclaredFields();
        for (Field declaredField : declaredFields) {
            if (!declaredField.getType().equals(CourseInService.class)) {
                continue;
            }

            log.info("{} dependency updated", value.getClass().getName());
            FieldUtils.writeField(declaredField, value, courseInService, true);
        }
    }

    log.info("after update {}", bean.toString());
}
