Spring Boot 2.x 整合 MongoDB：多数据源动态切换

初始化工作


加入如下依赖

 <!-- Spring Data MongoDB starter -->
 <dependency>
     <groupId>org.springframework.boot</groupId>
     <artifactId>spring-boot-starter-data-mongodb</artifactId>
     <version>2.2.5.RELEASE</version>
 </dependency>

增加application.yaml配置

spring:
  aop:
    auto: true
    # class-based (CGLIB) proxies, so advice also applies to classes without an interface
    proxy-target-class: true
mongo:
  datasource:
    dblist:
      - uri: mongodb://xxx:xxx@localhost:27017/basicdata
        database: basicdata # key referenced by @MongoSwitch("basicdata")
      - uri: mongodb://xxx:xxx@localhost:27017/common
        database: common # key referenced by @MongoSwitch("common")

在启动类上增加 @ConfigurationPropertiesScan 注解（MongoListProperties 只标注了 @ConfigurationProperties 而没有 @Component，需要靠该注解扫描并注册为 Bean）

/**
 * Application entry point.
 *
 * <p>{@code @ConfigurationPropertiesScan} registers {@code @ConfigurationProperties} classes
 * that carry no {@code @Component}, such as {@code MongoListProperties}.
 */
@SpringBootApplication
@ConfigurationPropertiesScan
@Slf4j
public class PushServerApplication {

  /**
   * Boots the Spring context and logs once startup has returned.
   *
   * @param args command-line arguments passed through to Spring Boot
   */
  public static void main(String[] args) {
    new SpringApplication(PushServerApplication.class).run(args);
    log.info("服务启动!");
  }
}

增加自定义配置读取类MongoListProperties

import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.springframework.boot.context.properties.ConfigurationProperties;

import java.util.List;

/**
 * @program: DataPushServer
 * @description: Binds the MongoDB connections listed under {@code mongo.datasource.dblist}.
 * @author: linwl
 * @create: 2020-07-11 14:41
 */
@ToString
@ConfigurationProperties(prefix = "mongo.datasource")
public class MongoListProperties {

  /** A single configured connection. */
  @Data
  public static class MongoList {
    /** MongoDB connection string. */
    private String uri;

    /** Key under which this connection is registered and selected by name. */
    private String database;
  }

  /** Configured connections; may be null when the property is absent. */
  @Getter
  @Setter
  private List<MongoList> dblist;
}

接下来配置mongo数据源


增加mongo数据源上下文切换MongoDbContext

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.data.mongodb.MongoDbFactory;
import org.springframework.data.mongodb.core.SimpleMongoClientDbFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import javax.annotation.PostConstruct;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * @program: DataPushServer
 * @description: MongoDB connection context. Holds one factory per configured data source plus
 *     the factory selected for the current thread.
 * @author: linwl
 * @create: 2020-07-11 14:31
 */
@Component
public class MongoDbContext {

  /**
   * Factories keyed by the configured {@code database} name. A LinkedHashMap keeps the
   * configuration order, so the default factory handed out below is the first configured entry
   * rather than whichever key a HashMap happens to iterate first.
   */
  private static final Map<String, MongoDbFactory> MONGO_CLIENT_DB_FACTORY_MAP =
      new LinkedHashMap<>();

  /** Factory selected for the current thread; {@code null} means "use the default". */
  private static final ThreadLocal<MongoDbFactory> MONGO_DB_FACTORY_THREAD_LOCAL =
      new ThreadLocal<>();

  @Autowired MongoListProperties mongoListProperties;

  /**
   * Returns the factory bound to the current thread.
   *
   * @return the bound factory, or {@code null} if none is bound
   */
  public static MongoDbFactory getMongoDbFactory() {
    return MONGO_DB_FACTORY_THREAD_LOCAL.get();
  }

  /**
   * Binds the data source registered under {@code name} to the current thread.
   *
   * @param name a configured {@code database} key; {@code null} or empty clears the binding so
   *     the default data source is used
   * @throws IllegalArgumentException if {@code name} is not configured. Without this check the
   *     operation would silently run against the default database.
   */
  public static void setMongoDbFactory(String name) {
    if (name == null || name.isEmpty()) {
      MONGO_DB_FACTORY_THREAD_LOCAL.remove();
      return;
    }
    MongoDbFactory factory = MONGO_CLIENT_DB_FACTORY_MAP.get(name);
    if (factory == null) {
      throw new IllegalArgumentException(
          "Unknown mongo data source '"
              + name
              + "'; configured: "
              + MONGO_CLIENT_DB_FACTORY_MAP.keySet());
    }
    MONGO_DB_FACTORY_THREAD_LOCAL.set(factory);
  }

  /** Clears the current thread's data-source binding. */
  public static void removeMongoDbFactory() {
    MONGO_DB_FACTORY_THREAD_LOCAL.remove();
  }

  /** Creates one {@link SimpleMongoClientDbFactory} per configured entry, in list order. */
  @PostConstruct
  public void afterPropertiesSet() {
    if (!CollectionUtils.isEmpty(mongoListProperties.getDblist())) {
      mongoListProperties
          .getDblist()
          .forEach(
              info ->
                  MONGO_CLIENT_DB_FACTORY_MAP.put(
                      info.getDatabase(), new SimpleMongoClientDbFactory(info.getUri())));
    }
  }

  /**
   * Dynamic template. The first configured data source is used when no data source is bound
   * to the calling thread.
   */
  @Bean(name = "mongoTemplate")
  public MultiMongoTemplate dynamicMongoTemplate() {
    return new MultiMongoTemplate(defaultMongoDbFactory());
  }

  /**
   * Exposes the default (first configured) factory as a bean. NOTE(review): presumably declared
   * so that Spring Boot's Mongo auto-configuration backs off — confirm.
   */
  @Bean(name = "mongoDbFactory")
  public MongoDbFactory mongoDbFactory() {
    return defaultMongoDbFactory();
  }

  /**
   * Returns the first configured factory.
   *
   * @throws IllegalStateException if {@code mongo.datasource.dblist} is empty or missing
   */
  private static MongoDbFactory defaultMongoDbFactory() {
    Iterator<MongoDbFactory> iterator = MONGO_CLIENT_DB_FACTORY_MAP.values().iterator();
    if (!iterator.hasNext()) {
      throw new IllegalStateException(
          "No mongo data source configured; add entries under mongo.datasource.dblist");
    }
    return iterator.next();
  }
}

增加MongoSwitch注解

import java.lang.annotation.*;

/**
 * @program: DataPushServer
 * @description: Selects the MongoDB data source used while the annotated method runs.
 * @author: linwl
 * @create: 2020-07-11 15:06
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface MongoSwitch {

  /**
   * Name of the data source to use. It is matched against the {@code database} keys configured
   * under {@code mongo.datasource.dblist}.
   *
   * @return the data source name; empty by default
   */
  String value() default "";
}

增加 AOP 切面类 MongoDbSwitch，在带有 @MongoSwitch 注解的方法执行前后切换数据源

import com.telpo.datapushserver.annotation.MongoSwitch;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.text.MessageFormat;

/**
 * @program: DataPushServer
 * @description: Aspect that switches the MongoDB data source around methods annotated with
 *     {@link MongoSwitch}.
 * @author: linwl
 * @create: 2020-07-11 14:26
 */
@Component
@Aspect
@Order(value = -99)
@Slf4j
public class MongoDbSwitch {

  /**
   * Per-thread stack of the data-source names selected by the annotated methods currently
   * executing. When a nested annotated call returns, {@link #after} uses it to restore the
   * enclosing method's selection. Without it, the context would be cleared and the rest of the
   * outer method would silently go to the default database.
   */
  private static final ThreadLocal<java.util.Deque<String>> SWITCH_STACK =
      ThreadLocal.withInitial(java.util.ArrayDeque::new);

  /**
   * Matches methods annotated with {@link MongoSwitch}. The annotation type is spelled out in
   * full because a {@code com.*} segment matches exactly one package name. It therefore cannot
   * match {@code com.telpo.datapushserver.annotation.MongoSwitch}.
   */
  @Pointcut("@annotation(com.telpo.datapushserver.annotation.MongoSwitch)")
  public void mongoSwitch() {}

  /**
   * Binds the data source named by the method's {@link MongoSwitch} to the current thread.
   * Failures are logged rather than propagated, so the annotated method still runs. In that
   * case the previously active selection stays in effect.
   *
   * @param point the intercepted join point
   */
  @Before("mongoSwitch()")
  public void before(JoinPoint point) {
    java.util.Deque<String> stack = SWITCH_STACK.get();
    // Until the switch below succeeds, the enclosing selection (if any) is what is active.
    String active = stack.isEmpty() ? "" : stack.peek();
    try {
      MongoSwitch mongoSwitch = findMongoSwitch(point);
      String name = mongoSwitch == null ? "" : mongoSwitch.value();
      log.info("切换 {}数据源", name);
      MongoDbContext.setMongoDbFactory(name);
      active = name;
    } catch (Exception e) {
      log.error("==========>前置mongo数据源切换异常", e);
    } finally {
      // Push exactly one entry per before() so that after() always stays balanced.
      stack.push(active);
    }
  }

  /**
   * Runs when the annotated method exits, whether normally or by exception. It restores the
   * enclosing annotated method's data source. When the outermost one finishes, it clears the
   * thread's context.
   *
   * @param point the intercepted join point
   */
  @After("mongoSwitch()")
  public void after(JoinPoint point) {
    try {
      java.util.Deque<String> stack = SWITCH_STACK.get();
      stack.poll();
      if (stack.isEmpty()) {
        log.info("移除mongo数据源上下文");
        MongoDbContext.removeMongoDbFactory();
        // Drop the per-thread stack too, so pooled threads do not retain it.
        SWITCH_STACK.remove();
      } else {
        MongoDbContext.setMongoDbFactory(stack.peek());
      }
    } catch (Exception e) {
      log.error("==========>后置mongo数据源切换异常", e);
    }
  }

  /**
   * Resolves the intercepted method on the target's class and reads its {@link MongoSwitch}.
   * If that lookup finds no annotation, it falls back to the join point signature's method.
   *
   * @param point the intercepted join point
   * @return the annotation, or {@code null} if neither method carries it
   * @throws NoSuchMethodException if the target class has no public method with this signature
   */
  private static MongoSwitch findMongoSwitch(JoinPoint point) throws NoSuchMethodException {
    MethodSignature methodSignature = (MethodSignature) point.getSignature();
    Method method =
        point
            .getTarget()
            .getClass()
            .getMethod(methodSignature.getName(), methodSignature.getParameterTypes());
    MongoSwitch mongoSwitch = method.getAnnotation(MongoSwitch.class);
    return mongoSwitch != null
        ? mongoSwitch
        : methodSignature.getMethod().getAnnotation(MongoSwitch.class);
  }
}

增加mongo多数据源配置MultiMongoTemplate

import com.mongodb.client.MongoDatabase;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.mongodb.MongoDbFactory;
import org.springframework.data.mongodb.core.MongoTemplate;

/**
 * @program: DataPushServer
 * @description: MongoTemplate that resolves its database on every call. It uses the factory
 *     bound to the current thread in {@link MongoDbContext} if there is one, and otherwise the
 *     factory given to the constructor.
 * @author: linwl
 * @create: 2020-07-11 14:21
 */
@Slf4j
public class MultiMongoTemplate extends MongoTemplate {

  /**
   * @param mongoDbFactory default factory, used when no factory is bound to the current thread
   */
  public MultiMongoTemplate(MongoDbFactory mongoDbFactory) {
    super(mongoDbFactory);
  }

  /** Returns the thread-bound factory's database, falling back to the default database. */
  @Override
  protected MongoDatabase doGetDatabase() {
    MongoDbFactory boundFactory = MongoDbContext.getMongoDbFactory();
    if (boundFactory != null) {
      return boundFactory.getDb();
    }
    return super.doGetDatabase();
  }
}

使用方法

先定义一个通用的 MongoDB DAO 抽象类，目的是规范使用（方法可自行酌情增加）

import java.util.Collection;
import java.util.List;
import java.util.Map;


/**
 * @program: DataPushServer
 * @description: Generic base DAO on top of {@link MongoTemplate}. Most operations come in two
 *     forms: one that uses the entity class's collection and one that takes an explicit
 *     collection name.
 * @author: linwl
 * @create: 2020-07-11 15:26
 */
@Slf4j
public abstract class BaseMongoDbDao<T extends BaseMongoDbEntity> {

    @Autowired private MongoTemplate mongoTemplate;

    /**
     * Returns the entity class handled by this DAO. It is passed to the template for mapping
     * and collection resolution.
     *
     * @return the concrete entity class
     */
    protected abstract Class<T> getEntityClass();

    /**
     * Saves an entity into the given collection.
     *
     * @param t the entity to save
     * @param collectionName target collection name
     */
    public void save(T t, String collectionName) {
        log.info("-------------->MongoDB save start");
        this.mongoTemplate.save(t, collectionName);
    }

    /**
     * Saves an entity into the entity class's collection.
     *
     * @param t the entity to save
     */
    public void save(T t) {
        log.info("-------------->MongoDB save start");
        this.mongoTemplate.save(t);
    }

    /**
     * Inserts a batch of entities into the entity class's collection.
     *
     * @param objectsToSave entities to insert
     * @return the inserted entities as returned by the template
     */
    public Collection<T> batchSave(Collection<? extends T> objectsToSave) {
        log.info("-------------->MongoDB batch save start");
        return this.mongoTemplate.insert(objectsToSave, this.getEntityClass());
    }

    /**
     * Inserts a batch of entities into the given collection.
     *
     * @param objectsToSave entities to insert
     * @param collectionName target collection name
     * @return the inserted entities as returned by the template
     */
    public Collection<T> batchSave(Collection<? extends T> objectsToSave, String collectionName) {
        log.info("-------------->MongoDB batch save start");
        return this.mongoTemplate.insert(objectsToSave, collectionName);
    }

    /**
     * Finds a document by {@code _id} in the entity class's collection.
     *
     * @param id document id
     * @return the matching entity, or {@code null} if none
     */
    public T queryById(String id) {
        Query query = new Query(Criteria.where("_id").is(id));
        log.info("-------------->MongoDB find start");
        return this.mongoTemplate.findOne(query, this.getEntityClass());
    }

    /**
     * Finds a document by {@code _id} in the given collection.
     *
     * @param id document id
     * @param collectionName collection to search
     * @return the matching entity, or {@code null} if none
     */
    public T queryById(String id, String collectionName) {
        Query query = new Query(Criteria.where("_id").is(id));
        log.info("-------------->MongoDB find start");
        return this.mongoTemplate.findOne(query, this.getEntityClass(), collectionName);
    }

    /**
     * Finds all documents matching the example object's properties.
     *
     * @param object example whose properties become equality criteria
     * @return matching entities
     */
    public List<T> queryList(T object) {
        Query query = getQueryByObject(object);
        log.info("-------------->MongoDB find start");
        return mongoTemplate.find(query, this.getEntityClass());
    }

    /**
     * Finds all documents in the given collection that match the example object's properties.
     *
     * @param object example whose properties become equality criteria
     * @param collectionName collection to search
     * @return matching entities
     */
    public List<T> queryList(T object, String collectionName) {
        Query query = getQueryByObject(object);
        log.info("-------------->MongoDB find start");
        return mongoTemplate.find(query, this.getEntityClass(), collectionName);
    }

    /**
     * Finds the first document matching the example object's properties.
     *
     * @param object example whose properties become equality criteria
     * @return the matching entity, or {@code null} if none
     */
    public T queryOne(T object) {
        Query query = getQueryByObject(object);
        log.info("-------------->MongoDB find start");
        return mongoTemplate.findOne(query, this.getEntityClass());
    }

    /**
     * Finds the first document in the given collection that matches the example object's
     * properties.
     *
     * @param object example whose properties become equality criteria
     * @param collectionName collection to search
     * @return the matching entity, or {@code null} if none
     */
    public T queryOne(T object, String collectionName) {
        Query query = getQueryByObject(object);
        log.info("-------------->MongoDB find start");
        return mongoTemplate.findOne(query, this.getEntityClass(), collectionName);
    }

    /**
     * Returns one page of documents matching the example object's properties.
     *
     * @param object example whose properties become equality criteria
     * @param start position of the first record; positive values are treated as 1-based
     * @param size maximum number of records to return
     * @return matching entities for the requested page
     */
    public List<T> getPage(T object, int start, int size) {
        Query query = getQueryByObject(object);
        query.skip(toSkip(start));
        query.limit(size);
        log.info("-------------->MongoDB queryPage start");
        return this.mongoTemplate.find(query, this.getEntityClass());
    }

    /**
     * Returns one page of documents from the given collection that match the example object's
     * properties.
     *
     * @param object example whose properties become equality criteria
     * @param start position of the first record; positive values are treated as 1-based
     * @param size maximum number of records to return
     * @param collectionName collection to search
     * @return matching entities for the requested page
     */
    public List<T> getPage(T object, int start, int size, String collectionName) {
        Query query = getQueryByObject(object);
        query.skip(toSkip(start));
        query.limit(size);
        log.info("-------------->MongoDB queryPage start");
        return this.mongoTemplate.find(query, this.getEntityClass(), collectionName);
    }

    /**
     * Counts documents matching the example object's properties.
     *
     * @param object example whose properties become equality criteria
     * @return number of matching documents
     */
    public Long getCount(T object) {
        Query query = getQueryByObject(object);
        log.info("-------------->MongoDB Count start");
        return this.mongoTemplate.count(query, this.getEntityClass());
    }

    /**
     * Counts documents in the given collection that match the example object's properties.
     *
     * @param object example whose properties become equality criteria
     * @param collectionName collection to search
     * @return number of matching documents
     */
    public Long getCount(T object, String collectionName) {
        Query query = getQueryByObject(object);
        log.info("-------------->MongoDB Count start");
        return this.mongoTemplate.count(query, this.getEntityClass(), collectionName);
    }

    /*
     * MongoDB offers three kinds of update:
     * 1: updateFirst - modify the first matching document
     * 2: updateMulti - modify every matching document
     * 3: upsert      - modify, inserting a new document when none matches
     */

    /**
     * Removes the given entity from the entity class's collection.
     *
     * @param t the entity to remove
     * @return number of deleted documents
     */
    public int delete(T t) {
        log.info("-------------->MongoDB delete start");
        return (int) this.mongoTemplate.remove(t).getDeletedCount();
    }

    /**
     * Removes the given entity from the given collection.
     *
     * @param t the entity to remove
     * @param collectionName collection to delete from
     * @return number of deleted documents
     */
    public int delete(T t, String collectionName) {
        log.info("-------------->MongoDB delete start");
        return (int) this.mongoTemplate.remove(t, collectionName).getDeletedCount();
    }

    /**
     * Deletes every document whose {@code _id} is in {@code ids}.
     *
     * @param ids ids to delete
     * @return number of deleted documents
     */
    public int delete(List<String> ids) {
        Query query = new Query(Criteria.where("_id").in(ids));
        log.info("-------------->MongoDB delete start");
        return (int) this.mongoTemplate.remove(query, this.getEntityClass()).getDeletedCount();
    }

    /**
     * Deletes every document in the given collection whose {@code _id} is in {@code ids}.
     *
     * @param ids ids to delete
     * @param collectionName collection to delete from
     * @return number of deleted documents
     */
    public int delete(List<String> ids, String collectionName) {
        Query query = new Query(Criteria.where("_id").in(ids));
        log.info("-------------->MongoDB delete start");
        return (int)
                this.mongoTemplate
                        .remove(query, this.getEntityClass(), collectionName)
                        .getDeletedCount();
    }

    /**
     * Deletes the document with the given id from the entity class's collection, if present.
     *
     * @param id document id
     */
    public void deleteById(String id) {
        Query query = new Query(Criteria.where("_id").is(id));
        T obj = this.mongoTemplate.findOne(query, this.getEntityClass());
        log.info("-------------->MongoDB deleteById start");
        if (obj != null) {
            this.delete(obj);
        }
    }

    /**
     * Deletes the document with the given id from the given collection, if present.
     *
     * @param id document id
     * @param collectionName collection to delete from
     */
    public void deleteById(String id, String collectionName) {
        Query query = new Query(Criteria.where("_id").is(id));
        // Look the document up in the same collection it is removed from, not in the entity
        // class's collection; otherwise documents that live only in collectionName are skipped.
        T obj = this.mongoTemplate.findOne(query, this.getEntityClass(), collectionName);
        log.info("-------------->MongoDB deleteById start");
        if (obj != null) {
            this.delete(obj, collectionName);
        }
    }

    /**
     * Updates the first document matching {@code srcObj} with the properties of
     * {@code targetObj}.
     *
     * @param srcObj example whose properties select the document
     * @param targetObj object whose properties are written with {@code $set}
     */
    public void updateFirst(T srcObj, T targetObj) {
        Query query = getQueryByObject(srcObj);
        Update update = getUpdateByObject(targetObj);
        log.info("-------------->MongoDB updateFirst start");
        this.mongoTemplate.updateFirst(query, update, this.getEntityClass());
    }

    /**
     * Updates the first matching document in the given collection.
     *
     * @param srcObj example whose properties select the document
     * @param targetObj object whose properties are written with {@code $set}
     * @param collectionName collection to update
     */
    public void updateFirst(T srcObj, T targetObj, String collectionName) {
        Query query = getQueryByObject(srcObj);
        Update update = getUpdateByObject(targetObj);
        log.info("-------------->MongoDB updateFirst start");
        this.mongoTemplate.updateFirst(query, update, collectionName);
    }

    /**
     * Updates every document matching {@code srcObj} with the properties of {@code targetObj}.
     *
     * @param srcObj example whose properties select the documents
     * @param targetObj object whose properties are written with {@code $set}
     */
    public void updateMulti(T srcObj, T targetObj) {
        Query query = getQueryByObject(srcObj);
        Update update = getUpdateByObject(targetObj);
        log.info("-------------->MongoDB updateMulti start");
        this.mongoTemplate.updateMulti(query, update, this.getEntityClass());
    }

    /**
     * Updates every matching document in the given collection.
     *
     * @param srcObj example whose properties select the documents
     * @param targetObj object whose properties are written with {@code $set}
     * @param collectionName collection to update
     */
    public void updateMulti(T srcObj, T targetObj, String collectionName) {
        Query query = getQueryByObject(srcObj);
        Update update = getUpdateByObject(targetObj);
        log.info("-------------->MongoDB updateMulti start");
        this.mongoTemplate.updateMulti(query, update, collectionName);
    }

    /**
     * Updates the matching document, inserting one when nothing matches (upsert).
     *
     * @param srcObj example whose properties select the document
     * @param targetObj object whose properties are written with {@code $set}
     */
    public void updateInsert(T srcObj, T targetObj) {
        Query query = getQueryByObject(srcObj);
        Update update = getUpdateByObject(targetObj);
        log.info("-------------->MongoDB updateInsert start");
        this.mongoTemplate.upsert(query, update, this.getEntityClass());
    }

    /**
     * Upserts in the given collection.
     *
     * @param srcObj example whose properties select the document
     * @param targetObj object whose properties are written with {@code $set}
     * @param collectionName collection to update
     */
    public void updateInsert(T srcObj, T targetObj, String collectionName) {
        Query query = getQueryByObject(srcObj);
        Update update = getUpdateByObject(targetObj);
        log.info("-------------->MongoDB updateInsert start");
        this.mongoTemplate.upsert(query, update, collectionName);
    }

    /**
     * Converts the paging methods' {@code start} into a skip count. Positive values are
     * decremented (1-based start); zero and negative values are passed through unchanged.
     */
    private static int toSkip(int start) {
        return start > 0 ? start - 1 : start;
    }

    /**
     * Converts an example object into a query: an AND of equality criteria, one per entry of
     * the object's property map. NOTE(review): assumes {@code BeanUtil} is Hutool's, whose
     * {@code beanToMap(bean, false, true)} keeps camelCase keys and skips null values. Confirm
     * against the project's imports.
     *
     * @param object example object
     * @return the resulting query
     * @author Jason
     */
    private Query getQueryByObject(T object) {
        Query query = new Query();
        Map<String, Object> dataMap = BeanUtil.beanToMap(object, false, true);
        Criteria criteria = new Criteria();
        for (Map.Entry<String, Object> entry : dataMap.entrySet()) {
            criteria.and(entry.getKey()).is(entry.getValue());
        }
        query.addCriteria(criteria);
        return query;
    }

    /**
     * Converts an object into an update that {@code $set}s each entry of its property map.
     * It relies on the same {@code BeanUtil.beanToMap} assumption as {@link #getQueryByObject}.
     *
     * @param object object carrying the new values
     * @return the resulting update
     * @author Jason
     */
    private Update getUpdateByObject(T object) {
        Update update = new Update();
        Map<String, Object> dataMap = BeanUtil.beanToMap(object, false, true);
        for (Map.Entry<String, Object> entry : dataMap.entrySet()) {
            update.set(entry.getKey(), entry.getValue());
        }
        return update;
    }
}

增加一个类继承上面的抽象类

import com.telpo.datapushserver.annotation.MongoSwitch;
import com.telpo.datapushserver.entity.mongo.LocationEntity;
import db.BaseMongoDbDao;
import org.springframework.stereotype.Repository;

/**
 * @program: DataPushServer
 * @description: DAO for {@link LocationEntity}; shows per-method data-source selection with
 *     {@link MongoSwitch}.
 * @author: linwl
 * @create: 2020-07-11 15:32
 *
 * <p>NOTE: Spring AOP advice is proxy-based, so the switch only applies when these methods are
 * called through the Spring bean, not when they are called from inside this class.
 */
@Repository
public class LocationMapper extends BaseMongoDbDao<LocationEntity> {

  /** Saves into the {@code basicdata} data source. */
  @MongoSwitch("basicdata")
  @Override
  public void save(LocationEntity entity) {
    super.save(entity);
  }

  /** Saves into {@code collectionName} of the {@code common} data source. */
  @MongoSwitch("common")
  @Override
  public void save(LocationEntity entity, String collectionName) {
    super.save(entity, collectionName);
  }

  /** Entity type handled by the base DAO's class-based operations. */
  @Override
  protected Class<LocationEntity> getEntityClass() {
    return LocationEntity.class;
  }
}

效果如下：
（原文此处为效果截图，图片已缺失）

你可能感兴趣的:(Springboot,mongodb,java)