mybatis自定义拦截器-数据权限过滤

   最近一段时间公司搞新项目,数据库orm选用了mybatis框架。使用一段时间mybaits后感觉比其他orm框架灵活好用,好处就不说了,网上一搜大把。本次主要讲下mybatis自定义拦截器功能的开发,通过拦截器可以解决项目中蛮多的问题,虽然很多功能不用拦截器也可以实现,但使用自定义拦截器实现功能从我角度至少以下优点(1)灵活,解耦(2)统一控制 ,减少开发工作量,不用散落到每个业务功能点去实现。

    一般业务系统项目都涉及到数据权限的控制,此次结合本项目记录下基于mybatis拦截器实现数据权限的过滤,因为项目用到mybatis-plus的分页插件,数据权限拦截过滤的时机也要控制好,在分页拦截器之前先拦截修改sql,不然会导致查询出来的数据同分页统计出来数量不一致。


拦截器基本知识


    Mybatis采用责任链模式,通过动态代理组织多个拦截器,通过这些拦截器可以改变mybatis的默认行为,编写自定义拦截器最好了解下它的原理,以便写出安全高效的插件。

 (1)拦截器均需要实现org.apache.ibatis.plugin.Interceptor 接口,对于自定义拦截器必须使用mybatis 提供的注解来指明我们要拦截的是四类中的哪一个类接口。

具体规则如下:

 a:Intercepts 标识我的类是一个拦截器

 b:Signature 则是指明我们的拦截器需要拦截哪一个接口的哪一个方法;type对应四类接口中的某一个,比如是 Executor;method对应接口中的哪类方法,比如 Executor 的 update 方法;args 对应接口中的哪一个方法,比如 Executor 中 query 因为重载原因,方法有多个,args 就是指明参数类型,从而确定是哪一个方法。

@Intercepts({

    @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})

})

(2) mybatis 拦截器默认可拦截的类型四种,即四种接口类型 Executor、StatementHandler、ParameterHandler 和 ResultSetHandler,对于我们的自定义拦截器必须使用 mybatis 提供的注解来指明我们要拦截的是四类中的哪一个类接口。

(3)拦截器顺序:

 不同类型拦截器的顺序Executor -> ParameterHandler -> StatementHandler ->ResultSetHandler

  同类型的拦截器的不同对象拦截顺序则根据 mybatis 核心配置文件的配置位置,拦截顺序是 从上往下,在mybatis 核心配置文件中需要配置我们的 plugin 


数据权限过滤


   1.实现业务需求的数据过滤,在用户访问数据库时进行权限判断并改造sql,达到限制低权限用户访问数据的目的

   2.采用技术:mybatis拦截器,java自定义注解,反射,开源jsqlparser

   3.核心业务流程图


4.代码实现

(1)创建自定义注解

```

import java.lang.annotation.Documented;

import java.lang.annotation.ElementType;

import java.lang.annotation.Inherited;

import java.lang.annotation.Retention;

import java.lang.annotation.RetentionPolicy;

import java.lang.annotation.Target;

/**

* 数据权限注解

*

*/

@Documented

@Target( value = { ElementType.TYPE, ElementType.METHOD } )

@Retention( RetentionPolicy.RUNTIME )

@Inherited

public @interface DataAuth

{

/**

* 追加sql的方法名

* @return

*/

public String method() default "whereSql";

/**

* 表别名

* @return

*/

public String tableAlias() default "";

}

```

(2)mapper方法增加权限注解

```

import com.baomidou.mybatisplus.core.mapper.BaseMapper;

import com.baomidou.mybatisplus.extension.plugins.pagination.Page;

import org.apache.ibatis.annotations.Param;

import java.util.List;

public interface TestMapper extends BaseMapper {

    /**

    * 增加权限注解

    *

    */

    @DataAuth(tableAlias = "o")

    List listData(TestQuery testQuery);

}

```

(3)创建自定义拦截器

```

import net.sf.jsqlparser.expression.Expression;

import net.sf.jsqlparser.expression.Parenthesis;

import net.sf.jsqlparser.expression.operators.conditional.AndExpression;

import net.sf.jsqlparser.parser.CCJSqlParserManager;

import net.sf.jsqlparser.parser.CCJSqlParserUtil;

import net.sf.jsqlparser.statement.select.PlainSelect;

import net.sf.jsqlparser.statement.select.Select;

import org.apache.commons.lang3.StringUtils;

import org.apache.ibatis.executor.Executor;

import org.apache.ibatis.mapping.BoundSql;

import org.apache.ibatis.mapping.MappedStatement;

import org.apache.ibatis.mapping.SqlCommandType;

import org.apache.ibatis.mapping.SqlSource;

import org.apache.ibatis.plugin.Interceptor;

import org.apache.ibatis.plugin.Intercepts;

import org.apache.ibatis.plugin.Invocation;

import org.apache.ibatis.plugin.Plugin;

import org.apache.ibatis.plugin.Signature;

import org.apache.ibatis.reflection.DefaultReflectorFactory;

import org.apache.ibatis.reflection.MetaObject;

import org.apache.ibatis.reflection.factory.DefaultObjectFactory;

import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;

import org.apache.ibatis.session.ResultHandler;

import org.apache.ibatis.session.RowBounds;

import org.slf4j.Logger;

import org.slf4j.LoggerFactory;

import org.springframework.beans.BeansException;

import org.springframework.context.ApplicationContext;

import org.springframework.context.ApplicationContextAware;

import org.springframework.stereotype.Component;

import java.io.StringReader;

import java.lang.reflect.Method;

import java.util.Map;

import java.util.Properties;

/**

* 数据权限拦截器

* 根据各个微服务,继承DataAuthService增加不同的where语句

*

*/

@Component

@Intercepts({@Signature(method = "query",type = Executor.class,args =  {

        MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}

    )

})

public class MybatisDataAuthInterceptor implements Interceptor,

    ApplicationContextAware {

    private final static Logger logger = LoggerFactory.getLogger(MybatisDataAuthInterceptor.class);

    private static ApplicationContext context;

    @Override

    public void setApplicationContext(ApplicationContext applicationContext)

        throws BeansException {

        context = applicationContext;

    }

    @Override

    public Object intercept(Invocation arg0) throws Throwable {

        MappedStatement mappedStatement = (MappedStatement) arg0.getArgs()[0];

        // 只对查询sql拦截

        if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {

            return arg0.proceed();

        }

        // String mSql = sql;

        // 注解逻辑判断 添加注解了才拦截追加

        Class classType = Class.forName(mappedStatement.getId()

                                                          .substring(0,

                    mappedStatement.getId().lastIndexOf(".")));

        String mName = mappedStatement.getId()

                                      .substring(mappedStatement.getId()

                                                                .lastIndexOf(".") +

                1, mappedStatement.getId().length()); //

        for (Method method : classType.getDeclaredMethods()) {

            if (method.isAnnotationPresent(DataAuth.class) &&

                    mName.equals(method.getName())) {

                /**

                * 查找标识了该注解 的实现 类

                */

                Map beanMap = context.getBeansWithAnnotation(DataAuth.class);

                if ((beanMap != null) && (beanMap.entrySet().size() > 0)) {

                    for (Map.Entry entry : beanMap.entrySet()) {

                        DataAuth action = method.getAnnotation(DataAuth.class);

                        if (StringUtils.isEmpty(action.method())) {

                            break;

                        }

                        try {

                            Method md = entry.getValue().getClass()

                                            .getMethod(action.method(),

                                    new Class[] { String.class });

                            /**

                            * 反射获取业务 sql

                            */

                            String whereSql = (String) md.invoke(context.getBean(

                                        entry.getValue().getClass()),

                                    new Object[] { action.tableAlias() });

                            if (!StringUtils.isEmpty(whereSql) &&

                                    !"null".equalsIgnoreCase(whereSql)) {

                                Object parameter = null;

                                if (arg0.getArgs().length > 1) {

                                    parameter = arg0.getArgs()[1];

                                }

                                BoundSql boundSql = mappedStatement.getBoundSql(parameter);

                                MappedStatement newStatement = newMappedStatement(mappedStatement,

                                        new BoundSqlSqlSource(boundSql));

                                MetaObject msObject = MetaObject.forObject(newStatement,

                                        new DefaultObjectFactory(),

                                        new DefaultObjectWrapperFactory(),

                                        new DefaultReflectorFactor());

                                /**

                                * 通过JSqlParser解析 原有sql,追加sql条件

                                */

                                CCJSqlParserManager parserManager = new CCJSqlParserManager();

                                Select select = (Select) parserManager.parse(new StringReader(

                                            boundSql.getSql()));

                                PlainSelect selectBody = (PlainSelect) select.getSelectBody();

                                Expression whereExpression = CCJSqlParserUtil.parseCondExpression(whereSql);

                                selectBody.setWhere(new AndExpression(

                                        selectBody.getWhere(),

                                        new Parenthesis(whereExpression)));

                                /**

                                * 修改sql

                                */

                                msObject.setValue("sqlSource.boundSql.sql",

                                    selectBody.toString());

                                arg0.getArgs()[0] = newStatement;

                                logger.info("Interceptor sql:" +

                                    selectBody.toString());

                            }

                        } catch (Exception e) {

                            logger.error(null, e);

                        }

                        break;

                    }

                }

            }

        }

        return arg0.proceed();

    }

    private MappedStatement newMappedStatement(MappedStatement ms,

        SqlSource newSqlSource) {

        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(),

                ms.getId(), newSqlSource, ms.getSqlCommandType());

        builder.resource(ms.getResource());

        builder.fetchSize(ms.getFetchSize());

        builder.statementType(ms.getStatementType());

        builder.keyGenerator(ms.getKeyGenerator());

        if ((ms.getKeyProperties() != null) &&

                (ms.getKeyProperties().length != 0)) {

            StringBuilder keyProperties = new StringBuilder();

            for (String keyProperty : ms.getKeyProperties()) {

                keyProperties.append(keyProperty).append(",");

            }

            keyProperties.delete(keyProperties.length() - 1,

                keyProperties.length());

            builder.keyProperty(keyProperties.toString());

        }

        builder.timeout(ms.getTimeout());

        builder.parameterMap(ms.getParameterMap());

        builder.resultMaps(ms.getResultMaps());

        builder.resultSetType(ms.getResultSetType());

        builder.cache(ms.getCache());

        builder.flushCacheRequired(ms.isFlushCacheRequired());

        builder.useCache(ms.isUseCache());

        return builder.build();

    }

    /**

    * 当目标类是Executor类型时,才包装目标类,否者直接返回目标本身,减少目标被代理的次数

    */

    @Override

    public Object plugin(Object target) {

        if (target instanceof Executor) {

            return Plugin.wrap(target, this);

        }

        return target;

    }

    @Override

    public void setProperties(Properties arg0) {

        // TODO Auto-generated method stub

    }

    class BoundSqlSqlSource implements SqlSource {

        private BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {

            this.boundSql = boundSql;

        }

        @Override

        public BoundSql getBoundSql(Object parameterObject) {

            return boundSql;

        }

    }

}

```

(4)增加业务逻辑

```

import com.baomidou.mybatisplus.core.toolkit.StringUtils;

import com.winhong.wincloud.constant.RoleTypeJudge;

import com.winhong.wincore.async.ThreadLocalHolder;

import com.winhong.wincore.user.LoginUserHolder;

import com.winhong.wincore.user.UserInfo;

import org.slf4j.Logger;

import org.slf4j.LoggerFactory;

import org.springframework.stereotype.Service;

@Service

public abstract class AbstractDataAuthService {

    private static final Logger LOG = LoggerFactory.getLogger(AbstractDataAuthService.class);

    /**

    * 默认查询sql,根据角色不同追加不同业务查询条件

    *

    * @return

    */

    public String whereSql(String tableAlias) {

        if (!StringUtils.isEmpty(tableAlias)) {

            tableAlias = tableAlias + ".";

        }

        StringBuffer sql = new StringBuffer();

        //利用threadlocal获取用户角色信息

        UserInfo userInfo = LoginUserHolder.getUser();

        // 普通 用户

        if (RoleTypeJudge.isNormalUser(userInfo.getRoleTypeCode())) {

            sql.append(nomalUserSql(userInfo.getUserUuid(), tableAlias));

        }

        // 管理员

        else if (RoleTypeJudge.isManager(userInfo.getRoleTypeCode())) {

            sql.append(managerSql(tableAlias));

        } else {

        }

        return sql.toString();

    }

}

```

你可能感兴趣的:(mybatis自定义拦截器-数据权限过滤)