Spring Boot 中动态增加数据源并根据接口切换数据源(支持 Sharding-JDBC、MyBatis、AOP 切面)

最近项目需要动态增加数据源,同时数据库按年份、按专题分开存储,因此需要支持动态查询和动态增加数据源。前一篇博客没有引入 MyBatis,需要手动拼接 SQL;这次把 MyBatis 引入进来,不用再手动拼接 SQL。以下是整个代码逻辑,代码还没有重构整理,先放出来。
第一步:屏蔽 Spring Boot 启动时的数据源自动配置(自动加载数据源)

 /**
 * Spring Boot entry point. DataSource and Hibernate JPA auto-configuration are excluded because
 * data sources are created manually by the application instead of by Spring Boot.
 */
@SpringBootApplication(exclude = {DataSourceAutoConfiguration.class,
    HibernateJpaAutoConfiguration.class})
public class PotatoDatagaeaAppsApplication {

    /** Boots the Spring application context. */
    public static void main(String[] args) {
        SpringApplication.run(PotatoDatagaeaAppsApplication.class, args);
    }

    /**
     * Web security setup that turns off CSRF protection and HTTP Basic authentication and lets
     * every request through without authentication.
     */
    @EnableWebSecurity
    public static class SecurityPermitAllConfig extends WebSecurityConfigurerAdapter {

        @Override
        protected void configure(HttpSecurity http) throws Exception {
            // Each call below configures the same HttpSecurity builder, in the same order as a
            // single fluent chain would.
            http.csrf().disable();
            http.authorizeRequests().anyRequest().permitAll();
            http.httpBasic().disable();
        }
    }
}

第二步:配置默认数据源,并注册动态数据源 Bean

@Configuration
public class DynamicDataSourceConfig {

    /** JDBC URL of the default database, from {@code spring.datasource.url}. */
    @Value("${spring.datasource.url}")
    private String defaultUrl;
    /** Login user of the default database. */
    @Value("${spring.datasource.username}")
    private String defaultUser;
    /** Login password of the default database. */
    @Value("${spring.datasource.password}")
    private String defaultPassword;
    /** JDBC driver class of the default database. */
    @Value("${spring.datasource.driverClassName}")
    private String defaultDriverClassName;

    /**
     * Publishes the singleton {@link DynamicDataSource}, seeded with a Druid pool for the database
     * described by the {@code spring.datasource.*} properties. That pool is registered under the
     * key {@code "default"} and is also the fallback target.
     *
     * @return the shared routing data source
     */
    @Bean
    public DynamicDataSource dynamicDataSource() {
        DruidDataSource bootstrapPool = DynamicDataSource
            .createDataSource(defaultUrl, defaultUser, defaultPassword, defaultDriverClassName);

        Map<Object, Object> initialTargets = new HashMap<>();
        initialTargets.put("default", bootstrapPool);

        DynamicDataSource routing = DynamicDataSource.getInstance();
        routing.setTargetDataSources(initialTargets);
        routing.setDefaultTargetDataSource(bootstrapPool);
        return routing;
    }
}


第三步:配置多数据源及多 SqlSession 管理模块

public class DynamicDataSource extends AbstractRoutingDataSource {

    private static DynamicDataSource instance;
    private static byte[] lock = new byte[0];

    public static synchronized DynamicDataSource getInstance() {
        if (instance == null) {
            synchronized (lock) {
                if (instance == null) {
                    instance = new DynamicDataSource();
                }
            }
        }
        return instance;
    }

    private static Map<Object, Object> dataSourceMap = new HashMap<Object, Object>();

    private static Map<Object, SqlSessionFactory> targetSqlSessionFactorys = new HashMap<>();

    @Resource
    private DataSourceCache dataSourceCache; // redis 缓存信息
    @Resource
    private RdsLinkRemote rdsLinkRemote;

    @Override
    protected Object determineCurrentLookupKey() {
        logger.info("concurrent dynamic data source " + DataSourceContext.getDataSource());
        return DataSourceContext.getDataSource();
    }

    @Override
    public void setTargetDataSources(Map<Object, Object> targetDataSources) {
        dataSourceMap.putAll(targetDataSources);
        super.setTargetDataSources(dataSourceMap);
        super.afterPropertiesSet();// 必须添加该句,否则新添加数据源无法识别到
    }

    public DruidDataSource getDataSource(String dataSourceName) {
        return (DruidDataSource) DynamicDataSource.dataSourceMap.get(dataSourceName);
    }
// 创建Session
    public SqlSessionFactory getSqlSessionFactory(String dataSourceName) throws Exception {
        if (null == targetSqlSessionFactorys.get(dataSourceName)) {
            if (null != getDataSource(dataSourceName)) {
                SqlSessionFactoryBean bean = new SqlSessionFactoryBean();
                bean.setDataSource(getDataSource(dataSourceName));
                bean.setMapperLocations(new PathMatchingResourcePatternResolver()
                    .getResources("classpath*:mappers/*.xml"));
                targetSqlSessionFactorys.put(dataSourceName, bean.getObject());
            }
        }
        return targetSqlSessionFactorys.get(dataSourceName);
    }

    public void clear() {
        DynamicDataSource.dataSourceMap.clear();
    }

    public static DruidDataSource createDataSource(String url, String userName, String Password,
        String driverName) {
        DruidDataSource druidDataSource = new DruidDataSource();
        druidDataSource.setUrl(url);
        druidDataSource.setUsername(userName);
        druidDataSource.setPassword(Password);
        druidDataSource.setDriverClassName(driverName);
        druidDataSource.setMaxPoolPreparedStatementPerConnectionSize(30);
        return druidDataSource;
    }

    public RdsLinkDetailResp getRdsLinkInfo(String rdsCode) {
        APIResponse<RdsLinkDetailResp> rdsinfoResp = rdsLinkRemote
            .detailByCode(rdsCode, PrincipalInfoContext.getAuthorization(),
                PrincipalInfoContext.getReqId());
        if (BizCommMessage.OK != rdsinfoResp.getCode()) {
            return null;
        }
        return rdsinfoResp.getData();
    }

    public void checkDataSource(RdsLinkDetailResp resp) throws Exception {
        DruidDataSource data = createDataSource(resp.getJdbcUrl(), resp.getUsername(),
            resp.getPassword(), resp.getRdsDriver());
        Connection conn = data.getConnection();
        conn.close();
        DruidDataSource finalData = data;
        setTargetDataSources(new HashMap<Object, Object>() {{
            put(resp.getCode(), (Object) finalData);
        }});

    }


    public void createDataSource(String code) {
        RdsLinkDetailResp rdsinfo = null;
        DruidDataSource data = getDataSource(code);
        try {
            Connection conn = data.getConnection();
            conn.close();
        } catch (Exception ex) {
            try {
                rdsinfo = dataSourceCache.getRDSLinkInfo(code);
                checkDataSource(rdsinfo);
            } catch (Exception inniex) {
                try {
                    rdsinfo = getRdsLinkInfo(code);
                    dataSourceCache.setRDSLinkInfo(rdsinfo);
                    checkDataSource(rdsinfo);
                } catch (Exception inputex) {
                    logger.error("createDataSource error:" + inputex.getMessage());
                }
            }
        }
    }
//--------创建分库分表查询数据源(shareding-jdbc datasource)
    public String createActualDataNodes(String rdsCode, String tableName,
        List<String> areaCodes) {
        String tables = "";
        RdsLinkDetailResp rdsinfo = dataSourceCache.getRDSLinkInfo(rdsCode);
        for (String areeCode : areaCodes) {
            tables += String
                .format("%s.%s_%s,", rdsinfo.getRdsDbName(), tableName, areeCode);
        }
        return tables;
    }

    public void createsharedDataSource(String rdsCode, String tableName,
        List<String> areaCodes) {
        try {
            createDataSource(rdsCode);
            RdsLinkDetailResp rdsinfo = dataSourceCache.getRDSLinkInfo(rdsCode);
            Map<String, DataSource> tempdataSourceMap = new HashMap<>();
            tempdataSourceMap
                .put(rdsinfo.getRdsDbName(), (DataSource) dataSourceMap.get(rdsinfo.getCode()));
            String tables = createActualDataNodes(rdsCode, tableName, areaCodes);
            ShardingRuleConfiguration shardingRuleConfig = new ShardingRuleConfiguration();
            if (null == dataSourceMap.get(tables)) {
                TableRuleConfiguration orderTableRuleConfig = new TableRuleConfiguration();
                // 期望的基础表名
                orderTableRuleConfig.setLogicTable(tableName);
                orderTableRuleConfig.setActualDataNodes(tables);
                shardingRuleConfig.getTableRuleConfigs().add(orderTableRuleConfig);
                DataSource dataSource = ShardingDataSourceFactory
                    .createDataSource(tempdataSourceMap, shardingRuleConfig, new HashMap<>(),
                        new Properties());
                String finalTables = tables;
                setTargetDataSources(new HashMap<Object, Object>() {{
                    put(finalTables, dataSource);
                }});
            }
        } catch (Exception ex) {
            ex.printStackTrace();
            logger.error("createsharedDataSource error:" + ex.getMessage());
        }
    }
}

线程内不切换数据源的使用办法

@Service
public class DLTBRepoImpl implements DLTBRepo {

    @Autowired
    private DLTBMapper dltbMapper;
    @Resource
    private DynamicDataSource dynamicDataSource;

    // NOTE(review): not referenced in this class. The field name is left unchanged because
    // @Resource resolves by field name before falling back to type.
    @Resource
    private AreaCodeRepo AreaCodeMapper;

    /**
     * Ensures the sharded data source for {@code tableName} over {@code areaCodes} exists, then
     * binds the current thread to it via {@code DataSourceContext}.
     */
    @Override
    public void commonInit(String rdsCode, String tableName, List<String> areaCodes) {
        dynamicDataSource.createsharedDataSource(rdsCode, tableName, areaCodes);
        DataSourceContext.setDataSource(
            dynamicDataSource.createActualDataNodes(rdsCode, tableName, areaCodes));
    }

    /** Delegates to the mapper on the thread's current data source. */
    @Override
    public List<Map<String, Object>> selectstatisticsDlbm(QuaryStatisticDLTBByAreaCode o) {
        return dltbMapper.selectstatisticsDlbm(o);
    }

    /** Row count for a paged geo query; assumes {@link #commonInit} was already called. */
    @Override
    public int selectCountByPage(GeoQuaryReq o) {
        return dltbMapper.selectCountByPage(o);
    }

    /** One page of a geo query; assumes {@link #commonInit} was already called. */
    @Override
    public List<Map<String, Object>> selectListByPage(GeoQuaryReq o) {
        return dltbMapper.selectListByPage(o);
    }
}



@Slf4j
@Service
public class TLSAreaCodeBiz {

    @Resource
    private AreaCodeRepo areaCodeRepo;
    @Resource
    private DictManagerBiz dictManagerBiz;

    /**
     * Loads area-code rows for every level from {@code beginlevel} to {@code endlevel} inclusive,
     * after switching to the common database. Level i reads table index i of
     * ["", "province", "city", "county", "villages"]; the first level is queried with isTop = 1.
     *
     * @return JSON with "array" (code/name/parentCode copies) and "tree" (AreaCodeEntity.listToTree)
     */
    public JSONObject getAreaCodeTree(int beginlevel, int endlevel, String areaCode) {
        JSONObject resultData = new JSONObject();
        List<String> levelTables = Arrays.asList("", "province", "city", "county", "villages");
        String commonDb = dictManagerBiz.getCommonDbInfo();
        areaCodeRepo.commonInit(commonDb);

        List<AreaCodeEntity> collected = new ArrayList<>();
        int level = beginlevel;
        while (level <= endlevel) {
            int isTop = (level == beginlevel) ? 1 : 0;
            collected.addAll(areaCodeRepo.selectTree(levelTables.get(level), level, isTop, areaCode));
            level++;
        }

        resultData.put("array", copyCodeNameParent(collected));
        resultData.put("tree", AreaCodeEntity.listToTree(collected));
        return resultData;
    }

    /** Returns fresh entities carrying only code, name and parentCode of each source entity. */
    private static List<AreaCodeEntity> copyCodeNameParent(List<AreaCodeEntity> source) {
        List<AreaCodeEntity> copies = new ArrayList<>();
        for (AreaCodeEntity original : source) {
            AreaCodeEntity copy = new AreaCodeEntity();
            copy.setCode(original.getCode());
            copy.setName(original.getName());
            copy.setParentCode(original.getParentCode());
            copies.add(copy);
        }
        return copies;
    }

    /**
     * Maps the request's area-code level to its table name. Defaults to "county" when the level is
     * absent or matches no known level.
     */
    private static String resolveAreaTableName(AreaCodeReq req) {
        String name = "county";
        if (null == req.getAreaCodeLevel()) {
            return name;
        }
        if (XzqLevelType.PROVINCE.getLevel() == req.getAreaCodeLevel()) {
            name = "province";
        }
        if (XzqLevelType.CITY.getLevel() == req.getAreaCodeLevel()) {
            name = "city";
        }
        if (XzqLevelType.COUNTY.getLevel() == req.getAreaCodeLevel()) {
            name = "county";
        }
        if (XzqLevelType.TOWN.getLevel() == req.getAreaCodeLevel()) {
            name = "villages";
        }
        return name;
    }

    /** Queries area codes from the level-specific table of the common database. */
    public List<AreaCodeEntity> getAreaCode(AreaCodeReq areaCodeReq) {
        return new BizTemplate<List<AreaCodeEntity>>() {
            String rdsCode = "";
            String tableName = "";

            @Override
            protected void checkParams() {
                rdsCode = dictManagerBiz.getCommonDbInfo();
                tableName = resolveAreaTableName(areaCodeReq);
                areaCodeRepo.commonInit(rdsCode);
            }

            @Override
            protected List<AreaCodeEntity> process() {
                return areaCodeRepo.selectAll(tableName, rdsCode,
                    areaCodeReq.getCurrentAreaCodeLevel(), areaCodeReq.getAreaCode(),
                    areaCodeReq.getWktGeom());
            }
        }.execute();
    }

    /**
     * Paged POI query against the common database. The page list is only loaded when the total
     * count is positive.
     */
    public PageResult<Map<String, Object>> quaeryByGeo(PoiQuaryReq req) {
        return new BizTemplate<PageResult<Map<String, Object>>>() {
            String rdsCode = "";
            String tableName = "poi";

            @Override
            protected void checkParams() {
                rdsCode = dictManagerBiz.getCommonDbInfo();
                areaCodeRepo.commonInit(rdsCode);
                req.setTableName(tableName);
            }

            @Override
            protected PageResult<Map<String, Object>> process() {
                if (req.getPageNo() >= 1) {
                    // Page numbers are 1-based.
                    req.setOffset((req.getPageNo() - 1) * req.getPageSize());
                }
                PageResult<Map<String, Object>> page = new PageResult<>();
                int total = areaCodeRepo.selectCountByPage(req);
                page.setTotal(total);
                if (total <= 0) {
                    return page;
                }
                page.setList(areaCodeRepo.selectListByPage(req));
                return page;
            }
        }.execute();
    }
}

同一线程内切换两种数据源、调用 MyBatis 查询数据

@Service
public class AreaCodeRepoImpl implements AreaCodeRepo {

    @Autowired
    private AreaCodeMapper areaCodeMapper;
    @Resource
    private DynamicDataSource dynamicDataSource;

    /**
     * Queries {@code tablename} through the Spring-managed mapper, i.e. on whatever data source the
     * current thread is routed to. {@code rdsCode} is not used here. The trailing literal 1 is
     * passed through to the mapper (meaning defined in the mapper XML; TODO document).
     *
     * @return the rows, or {@code null} if the query threw (the stack trace is printed)
     */
    @Override
    public List<AreaCodeEntity> selectAll(String tablename, String rdsCode, Integer areaCodeLevel,
        String areaCode, String wktGeom) {
        try {
            return areaCodeMapper.selectAll(tablename, areaCodeLevel, areaCode, wktGeom, 1);
        } catch (Exception ex) {
            ex.printStackTrace();
        }
        return null;
    }

    /** Binds the current thread to {@code rdsCode} and makes sure its pool is registered. */
    @Override
    public void commonInit(String rdsCode) {
        DataSourceContext.setDataSource(rdsCode);
        dynamicDataSource.createDataSource(rdsCode);
    }

    @Override
    public int selectCountByPage(PoiQuaryReq o) {
        return areaCodeMapper.selectCountByPage(o);
    }

    @Override
    public List<Map<String, Object>> selectListByPage(PoiQuaryReq o) {
        return areaCodeMapper.selectListByPage(o);
    }

    /**
     * Queries {@code rdsCode} through a dedicated SqlSession without touching the thread's routing
     * key, so the thread can keep using its current data source for other mappers.
     *
     * @throws SQLException if no session can be opened for {@code rdsCode} (e.g. the data source
     *     is not registered or the factory cannot be built)
     */
    @Override
    public List<AreaCodeEntity> selectAllByNotSwitch(String tablename, String rdsCode,
        Integer areaCodeLevel, String areaCode, String wktGeom) throws SQLException {
        dynamicDataSource.createDataSource(rdsCode);
        SqlSession opened;
        try {
            opened = dynamicDataSource.getSqlSessionFactory(rdsCode).openSession();
        } catch (Exception e) {
            // Covers a null factory (unregistered data source) as well as build failures.
            throw new SQLException("Unable to open SqlSession for data source " + rdsCode, e);
        }
        // Close the session even when the query throws.
        try (SqlSession sqlSession = opened) {
            AreaCodeMapper mapper = sqlSession.getMapper(AreaCodeMapper.class);
            return mapper.selectAll(tablename, areaCodeLevel, areaCode, wktGeom, 0);
        }
    }

    @Override
    public List<AreaCodeEntity> selectTree(String tablename, Integer beginlevel, Integer itTop,
        String areaCode) {
        return areaCodeMapper.selectTree(tablename, beginlevel, itTop, areaCode);
    }
}

通过切面(AOP)拦截方法参数中的数据源编码,并据此切换数据源

/**
 * Method marker picked up by the {@code @annotation(...DBAspect)} pointcut: before an annotated
 * method runs, its arguments are scanned for an RDS data-source code.
 */
@Documented
@Mapping
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface DBAspect {
}

@Aspect
@Order(-1)
@Component
public class DynamicDataSourceAspect {

    /** Substring that identifies an argument as an RDS data-source code. */
    private static final String RDS_CODE_MARKER = "RDS.";

    @Autowired
    private DynamicDataSource dynamicDataSource;
    private static final Logger logger = LoggerFactory.getLogger(DynamicDataSourceAspect.class);

    /**
     * Before any {@code @DBAspect} method, binds the thread to each argument whose string form
     * contains {@code "RDS."} and ensures that data source is registered. If several arguments
     * match, the last one wins. Null arguments are skipped.
     */
    @Before(value = "@annotation(com.tudou.potato.datagaea.apps.cache.DBAspect)")
    public void switchDataSource(JoinPoint point) {
        for (Object arg : point.getArgs()) {
            if (arg == null) {
                continue;
            }
            String candidate = arg.toString();
            if (candidate.contains(RDS_CODE_MARKER)) {
                DataSourceContext.setDataSource(candidate);
                dynamicDataSource.createDataSource(candidate);
                logger.info("switched data source to {}", candidate);
            }
        }
    }

}

代码尚未整理,目前只是能满足项目上的使用;等不忙的时候重新设计一下,再另发一篇。

你可能感兴趣的:(spring,boot,spring,cloud)