最近一段时间公司搞新项目,数据库orm选用了mybatis框架。使用一段时间mybaits后感觉比其他orm框架灵活好用,好处就不说了,网上一搜大把。本次主要讲下mybatis自定义拦截器功能的开发,通过拦截器可以解决项目中蛮多的问题,虽然很多功能不用拦截器也可以实现,但使用自定义拦截器实现功能从我角度至少以下优点(1)灵活,解耦(2)统一控制 ,减少开发工作量,不用散落到每个业务功能点去实现。
一般业务系统项目都涉及到数据权限的控制,此次结合本项目记录下基于mybatis拦截器实现数据权限的过滤,因为项目用到mybatis-plus的分页插件,数据权限拦截过滤的时机也要控制好,在分页拦截器之前先拦截修改sql,不然会导致查询出来的数据同分页统计出来数量不一致。
拦截器基本知识
Mybatis采用责任链模式,通过动态代理组织多个拦截器,通过这些拦截器可以改变mybatis的默认行为,编写自定义拦截器最好了解下它的原理,以便写出安全高效的插件。
(1)拦截器均需要实现org.apache.ibatis.plugin.Interceptor 接口,对于自定义拦截器必须使用mybatis 提供的注解来指明我们要拦截的是四类中的哪一个类接口。
具体规则如下:
a:Intercepts 标识我的类是一个拦截器
b:Signature 则是指明我们的拦截器需要拦截哪一个接口的哪一个方法;type对应四类接口中的某一个,比如是 Executor;method对应接口中的哪类方法,比如 Executor 的 update 方法;args 对应接口中的哪一个方法,比如 Executor 中 query 因为重载原因,方法有多个,args 就是指明参数类型,从而确定是哪一个方法。
@Intercepts({
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
(2) mybatis 拦截器默认可拦截的类型四种,即四种接口类型 Executor、StatementHandler、ParameterHandler 和 ResultSetHandler,对于我们的自定义拦截器必须使用 mybatis 提供的注解来指明我们要拦截的是四类中的哪一个类接口。
(3)拦截器顺序:
不同类型拦截器的顺序Executor -> ParameterHandler -> StatementHandler ->ResultSetHandler
同类型的拦截器的不同对象拦截顺序则根据 mybatis 核心配置文件的配置位置,拦截顺序是 从上往下,在mybatis 核心配置文件中需要配置我们的 plugin
数据权限过滤
1.实现业务需求的数据过滤,在用户访问数据库时进行权限判断并改造sql,达到限制低权限用户访问数据的目的
2.采用技术:mybatis拦截器,java自定义注解,反射,开源jsqlparser
3.核心业务流程图
4.代码实现
(1)创建自定义注解
```
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 数据权限注解
*
*/
@Documented
@Target( value = { ElementType.TYPE, ElementType.METHOD } )
@Retention( RetentionPolicy.RUNTIME )
@Inherited
public @interface DataAuth
{
/**
* 追加sql的方法名
* @return
*/
public String method() default "whereSql";
/**
* 表别名
* @return
*/
public String tableAlias() default "";
}
```
(2)mapper方法增加权限注解
```
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.apache.ibatis.annotations.Param;
import java.util.List;
public interface TestMapper extends BaseMapper
/**
* 增加权限注解
*
*/
@DataAuth(tableAlias = "o")
List
}
```
(3)创建自定义拦截器
```
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
import java.io.StringReader;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Properties;
/**
* 数据权限拦截器
* 根据各个微服务,继承DataAuthService增加不同的where语句
*
*/
@Component
@Intercepts({@Signature(method = "query",type = Executor.class,args = {
MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
)
})
public class MybatisDataAuthInterceptor implements Interceptor,
ApplicationContextAware {
private final static Logger logger = LoggerFactory.getLogger(MybatisDataAuthInterceptor.class);
private static ApplicationContext context;
@Override
public void setApplicationContext(ApplicationContext applicationContext)
throws BeansException {
context = applicationContext;
}
@Override
public Object intercept(Invocation arg0) throws Throwable {
MappedStatement mappedStatement = (MappedStatement) arg0.getArgs()[0];
// 只对查询sql拦截
if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
return arg0.proceed();
}
// String mSql = sql;
// 注解逻辑判断 添加注解了才拦截追加
Class> classType = Class.forName(mappedStatement.getId()
.substring(0,
mappedStatement.getId().lastIndexOf(".")));
String mName = mappedStatement.getId()
.substring(mappedStatement.getId()
.lastIndexOf(".") +
1, mappedStatement.getId().length()); //
for (Method method : classType.getDeclaredMethods()) {
if (method.isAnnotationPresent(DataAuth.class) &&
mName.equals(method.getName())) {
/**
* 查找标识了该注解 的实现 类
*/
Map
if ((beanMap != null) && (beanMap.entrySet().size() > 0)) {
for (Map.Entry
DataAuth action = method.getAnnotation(DataAuth.class);
if (StringUtils.isEmpty(action.method())) {
break;
}
try {
Method md = entry.getValue().getClass()
.getMethod(action.method(),
new Class[] { String.class });
/**
* 反射获取业务 sql
*/
String whereSql = (String) md.invoke(context.getBean(
entry.getValue().getClass()),
new Object[] { action.tableAlias() });
if (!StringUtils.isEmpty(whereSql) &&
!"null".equalsIgnoreCase(whereSql)) {
Object parameter = null;
if (arg0.getArgs().length > 1) {
parameter = arg0.getArgs()[1];
}
BoundSql boundSql = mappedStatement.getBoundSql(parameter);
MappedStatement newStatement = newMappedStatement(mappedStatement,
new BoundSqlSqlSource(boundSql));
MetaObject msObject = MetaObject.forObject(newStatement,
new DefaultObjectFactory(),
new DefaultObjectWrapperFactory(),
new DefaultReflectorFactor());
/**
* 通过JSqlParser解析 原有sql,追加sql条件
*/
CCJSqlParserManager parserManager = new CCJSqlParserManager();
Select select = (Select) parserManager.parse(new StringReader(
boundSql.getSql()));
PlainSelect selectBody = (PlainSelect) select.getSelectBody();
Expression whereExpression = CCJSqlParserUtil.parseCondExpression(whereSql);
selectBody.setWhere(new AndExpression(
selectBody.getWhere(),
new Parenthesis(whereExpression)));
/**
* 修改sql
*/
msObject.setValue("sqlSource.boundSql.sql",
selectBody.toString());
arg0.getArgs()[0] = newStatement;
logger.info("Interceptor sql:" +
selectBody.toString());
}
} catch (Exception e) {
logger.error(null, e);
}
break;
}
}
}
}
return arg0.proceed();
}
private MappedStatement newMappedStatement(MappedStatement ms,
SqlSource newSqlSource) {
MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(),
ms.getId(), newSqlSource, ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
if ((ms.getKeyProperties() != null) &&
(ms.getKeyProperties().length != 0)) {
StringBuilder keyProperties = new StringBuilder();
for (String keyProperty : ms.getKeyProperties()) {
keyProperties.append(keyProperty).append(",");
}
keyProperties.delete(keyProperties.length() - 1,
keyProperties.length());
builder.keyProperty(keyProperties.toString());
}
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
}
/**
* 当目标类是Executor类型时,才包装目标类,否者直接返回目标本身,减少目标被代理的次数
*/
@Override
public Object plugin(Object target) {
if (target instanceof Executor) {
return Plugin.wrap(target, this);
}
return target;
}
@Override
public void setProperties(Properties arg0) {
// TODO Auto-generated method stub
}
class BoundSqlSqlSource implements SqlSource {
private BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
}
```
(4)增加业务逻辑
```
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.winhong.wincloud.constant.RoleTypeJudge;
import com.winhong.wincore.async.ThreadLocalHolder;
import com.winhong.wincore.user.LoginUserHolder;
import com.winhong.wincore.user.UserInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
@Service
public abstract class AbstractDataAuthService {
private static final Logger LOG = LoggerFactory.getLogger(AbstractDataAuthService.class);
/**
* 默认查询sql,根据角色不同追加不同业务查询条件
*
* @return
*/
public String whereSql(String tableAlias) {
if (!StringUtils.isEmpty(tableAlias)) {
tableAlias = tableAlias + ".";
}
StringBuffer sql = new StringBuffer();
//利用threadlocal获取用户角色信息
UserInfo userInfo = LoginUserHolder.getUser();
// 普通 用户
if (RoleTypeJudge.isNormalUser(userInfo.getRoleTypeCode())) {
sql.append(nomalUserSql(userInfo.getUserUuid(), tableAlias));
}
// 管理员
else if (RoleTypeJudge.isManager(userInfo.getRoleTypeCode())) {
sql.append(managerSql(tableAlias));
} else {
}
return sql.toString();
}
}
```