任务目标:
利用 MyBatis 的 SQL 拦截器对即将执行的 SQL 进行二次修改,自动追加租户过滤条件。
废话少说,直接上代码。
首先注册 SQL 拦截器。SqlInterceptor 是我自己定义的类,找不到没关系,下文会给出它的实现。
package com.example.pidog.config;
import com.example.pidog.interceptor.SqlInterceptor;
import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import javax.annotation.PostConstruct;
import java.util.List;
/**
 * Registers {@link SqlInterceptor} on every MyBatis {@link SqlSessionFactory} bean.
 */
@Configuration
public class MybatisConfig {

    /** All SqlSessionFactory beans found in the application context. */
    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    /**
     * Adds one shared {@link SqlInterceptor} instance to the configuration of each factory,
     * after dependency injection has completed.
     */
    @PostConstruct
    public void addMySqlInterceptor() {
        SqlInterceptor interceptor = new SqlInterceptor();
        for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) {
            sqlSessionFactory.getConfiguration().addInterceptor(interceptor);
        }
    }
}
package com.example.pidog.interceptor;
import cn.hutool.core.collection.CollectionUtil;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.example.pidog.handlers.TenantSqlHandle;
import com.example.pidog.util.StaticMethodGetBean;
import lombok.Setter;
import lombok.experimental.Accessors;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.RowBounds;
import javax.annotation.Resource;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.List;
/**
 * MyBatis plugin that rewrites the SQL of selected mapper methods so that they are
 * restricted to the current tenant (the rewrite itself is done by {@link TenantSqlHandle}).
 */
@Setter
@Accessors(chain = true)
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class SqlInterceptor implements Interceptor {

    /** Simple names of the mapper methods whose SQL receives the tenant condition. */
    private static final List<String> TENANT_METHOD_NAMES = CollectionUtil.newArrayList("selectById", "updateById");

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Unwrap plugin proxies so the MetaObject path below resolves on the real handler.
        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // Assumes the real target is MyBatis's RoutingStatementHandler, whose "delegate"
        // holds the MappedStatement.
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        String id = mappedStatement.getId();
        if (filterMethodById(id)) {
            BoundSql boundSql = statementHandler.getBoundSql();
            // This is the final SQL that will be prepared, i.e. the text we rewrite.
            String sql = new TenantSqlHandle().tenantHandle(id, mappedStatement.getSqlCommandType(), boundSql.getSql());
            // BoundSql exposes no setter for its SQL text, so overwrite the field reflectively.
            Field field = BoundSql.class.getDeclaredField("sql");
            field.setAccessible(true);
            field.set(boundSql, sql);
        }
        return invocation.proceed();
    }

    /**
     * Wraps only statement handlers; every other MyBatis component is returned untouched.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Decides whether the statement with the given id should get the tenant condition.
     *
     * @param id the MappedStatement id, typically {@code <mapper namespace>.<method name>}
     * @return true if the part after the last '.' is one of {@link #TENANT_METHOD_NAMES}
     */
    private boolean filterMethodById(String id) {
        String methodName = id.substring(id.lastIndexOf('.') + 1);
        return TENANT_METHOD_NAMES.contains(methodName);
    }
}
package com.example.pidog.handlers;
import com.example.pidog.threadLocal.TenantThreadLocal;
import javafx.util.Pair;
import org.apache.ibatis.mapping.SqlCommandType;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Component;
import java.lang.reflect.Method;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Tenant SQL handler: adds the current tenant's condition to a statement's WHERE clause.
 */
public class TenantSqlHandle {

    /** First standalone WHERE keyword, in any letter case. */
    private static final Pattern WHERE_KEYWORD = Pattern.compile("\\bWHERE\\b", Pattern.CASE_INSENSITIVE);

    /** Plain SQL identifier, used to validate the tenant column name before splicing it in. */
    private static final Pattern IDENTIFIER = Pattern.compile("[A-Za-z_][A-Za-z0-9_]*");

    /** Integer tenant ids are emitted unquoted, exactly as before. */
    private static final Pattern NUMERIC_ID = Pattern.compile("-?\\d+");

    /** Non-numeric tenant ids may only use these characters and are emitted as a quoted literal. */
    private static final Pattern SAFE_ID = Pattern.compile("[A-Za-z0-9_-]+");

    /**
     * Inserts {@code <column>=<tenant id> and } right after the first WHERE keyword of {@code sql}.
     *
     * <p>NOTE(review): the condition is prepended with AND and the existing condition is not
     * parenthesised. So a WHERE clause with a top-level OR ({@code WHERE a=? OR b=?}) would become
     * {@code (tenant AND a) OR b}. This is presumably fine for the simple by-id statements this
     * is applied to, but must be revisited before widening its use.
     *
     * @param id             MappedStatement id (currently unused)
     * @param sqlCommandType statement type (currently unused)
     * @param sql            the SQL about to be executed
     * @return the rewritten SQL. The input is returned unchanged when no tenant is bound to the
     *         current thread or when the SQL has no WHERE keyword.
     * @throws IllegalArgumentException if the tenant column name or id contains characters that
     *                                  cannot be spliced into SQL safely
     */
    public String tenantHandle(String id, SqlCommandType sqlCommandType, String sql) {
        Pair<?, ?> dept = TenantThreadLocal.getCurrentDept();
        if (dept == null || sql == null) {
            return sql;
        }
        Matcher where = WHERE_KEYWORD.matcher(sql);
        if (!where.find()) {
            return sql;
        }
        String column = String.valueOf(dept.getKey());
        if (!IDENTIFIER.matcher(column).matches()) {
            throw new IllegalArgumentException("Illegal tenant column name: " + column);
        }
        String condition = " " + column + "=" + toSqlLiteral(dept.getValue()) + " and ";
        return new StringBuilder(sql).insert(where.end(), condition).toString();
    }

    /** Renders a tenant id as a SQL literal, rejecting anything that could alter the statement. */
    private static String toSqlLiteral(Object value) {
        String text = String.valueOf(value);
        if (NUMERIC_ID.matcher(text).matches()) {
            return text;
        }
        if (SAFE_ID.matcher(text).matches()) {
            // Only [A-Za-z0-9_-] can reach here, so no quote or backslash can break out.
            return "'" + text + "'";
        }
        throw new IllegalArgumentException("Illegal tenant id: " + text);
    }
}
package com.example.pidog.threadLocal;
import com.example.pidog.domain.User;
import javafx.util.Pair;
public class TenantThreadLocal {
/**
* 构造函数私有
*/
private TenantThreadLocal() {
}
private static final ThreadLocal> TENANT_THREAD_LOCAL =
new ThreadLocal<>();
/**
* 清除租户信息
*/
public static void clear() {
TENANT_THREAD_LOCAL.remove();
}
/**
* 存储当前租户
*/
public static void set(Pair dept) {
TENANT_THREAD_LOCAL.set(dept);
}
/**
* 获取当前租户id
*/
public static Pair getCurrentDept() {
return TENANT_THREAD_LOCAL.get();
}
}
最后,用上面的 ThreadLocal 工具类把当前租户信息保存在线程变量里。
接下来修改请求拦截器:从请求携带的 token 中解析出租户 id,放入线程变量中。
package com.example.pidog.interceptor;
import com.example.pidog.annotation.authority.AuthUtil;
import com.example.pidog.annotation.authority.CrossAuth;
import com.example.pidog.annotation.sql.CrossTenant;
import com.example.pidog.annotation.sql.MultiTenant;
import com.example.pidog.config.JwtConfig;
import com.example.pidog.threadLocal.TenantThreadLocal;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.SignatureException;
import javafx.util.Pair;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.Objects;
/**
 * Authenticates each request and binds its tenant to the current thread for SQL rewriting.
 */
@Component
public class TokenInterceptor implements HandlerInterceptor {
    @Resource
    private AuthUtil authUtil;

    /**
     * Validates the token unless the handler method is annotated with {@link CrossAuth}. Then,
     * unless the method is annotated with {@link CrossTenant}, stores the token's tenantId claim
     * in {@link TenantThreadLocal}.
     *
     * @throws SignatureException if the token is invalid or carries no tenantId
     */
    @Override
    public boolean preHandle(HttpServletRequest request,
                             HttpServletResponse response,
                             Object handler) throws SignatureException {
        // Threads are reused across requests: drop any tenant left by a previous request first,
        // so the early returns below never run with another request's tenant.
        TenantThreadLocal.clear();
        // NOTE(review): this cast fails for handlers that are not controller methods
        // (e.g. static resources); confirm that is intended.
        Method method = ((HandlerMethod) handler).getMethod();
        if (method.isAnnotationPresent(CrossAuth.class)) {
            return true;
        }
        Claims claims = authUtil.examineToken(request);
        if (method.isAnnotationPresent(CrossTenant.class)) { // endpoint opts out of tenant filtering
            return true;
        }
        Object tenantId = claims.get("tenantId");
        if (tenantId == null) {
            throw new SignatureException("参数非法,权限校验失败");
        }
        TenantThreadLocal.set(new Pair<>("tenantId", tenantId.toString()));
        return true;
    }

    /**
     * Unbinds the tenant once the request is complete so it cannot leak to the thread's next request.
     */
    @Override
    public void afterCompletion(HttpServletRequest request,
                                HttpServletResponse response,
                                Object handler,
                                Exception ex) {
        TenantThreadLocal.clear();
    }
}
以上是我的请求拦截器,核心的 token 校验逻辑在 AuthUtil 中:
package com.example.pidog.annotation.authority;
import com.example.pidog.config.JwtConfig;
import com.example.pidog.domain.User;
import com.example.pidog.threadLocal.UserThreadLocal;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.SignatureException;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
/**
 * Validates the JWT carried by a request.
 */
@Component
public class AuthUtil {
    @Resource
    private JwtConfig jwtConfig;

    /**
     * Validates the request's token and stores the resulting user in {@link UserThreadLocal}.
     * The token is read from the header named by {@code jwtConfig.getHeader()}, falling back to a
     * request parameter with the same name.
     *
     * @param request the current HTTP request
     * @return the token's claims
     * @throws SignatureException if the token is missing, cannot be parsed, is expired, lacks the
     *                            userId/nickName claims, or does not match the cached token for
     *                            that user
     */
    public Claims examineToken(HttpServletRequest request) {
        String header = jwtConfig.getHeader();
        String token = request.getHeader(header);
        if (!StringUtils.hasLength(token)) {
            token = request.getParameter(header);
        }
        if (!StringUtils.hasLength(token)) {
            throw new SignatureException(header + "不能为空");
        }
        String invalidMessage = header + "失效,请重新登录。";
        try {
            Claims claims = jwtConfig.getTokenClaim(token);
            if (claims == null
                    || jwtConfig.isTokenExpired(claims.getExpiration())
                    || claims.get("userId") == null
                    || claims.get("nickName") == null) {
                throw new SignatureException(invalidMessage);
            }
            User user = new User();
            user.setUserId(claims.get("userId", Integer.class).longValue());
            user.setNickName(claims.get("nickName", String.class));
            if (!jwtConfig.comparisonCache(user.getUserId(), token)) {
                throw new SignatureException(invalidMessage);
            }
            UserThreadLocal.set(user);
            return claims;
        } catch (Exception e) {
            // Every failure is reported to the client with the same message; the original
            // exception is kept as the cause for diagnostics instead of being printed to stderr.
            throw new SignatureException(invalidMessage, e);
        }
    }
}