Implementing Data Permissions (MyBatis Interceptor + JSqlParser)

I am still a beginner with limited knowledge. This article records what I learned while implementing data permissions.

Overall Idea

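[Figure 1: overall design]

In short: a MyBatis interceptor catches each SQL statement just before the JDBC Statement is prepared, JSqlParser parses the SQL, the current user's restrictions for the tables in FROM are added to the WHERE clause, and the rewritten SQL is written back before execution.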

I. MyBatis Interceptor

References:

MyBatis official website (Chinese)

imooc (慕课网) video course on MyBatis

SQL parsing

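Quoting the official documentation:

MyBatis allows you to intercept calls at certain points within the execution of a mapped statement. By default, MyBatis allows plug-ins to intercept the following method calls: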

Executor(update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)

ParameterHandler(getParameterObject, setParameters)

ResultSetHandler(handleResultSets, handleOutputParameters)

StatementHandler(prepare, parameterize, batch, update, query)

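Thanks to this mechanism, using a plug-in is simple: implement the Interceptor interface and specify the signatures of the methods you want to intercept.

MyBatis calls this feature a plugin, but what it provides is exactly the interceptor we need.

Methods and parameters:

1. The Interceptor interface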

public interface Interceptor {

  Object intercept(Invocation invocation) throws Throwable;

  Object plugin(Object target);

  void setProperties(Properties properties);

}

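Implementing Interceptor means implementing its three methods, intercept, plugin, and setProperties:

intercept is where we act on the intercepted call, i.e. where our own logic goes. It must eventually call invocation.proceed(), otherwise the original method never runs.

plugin receives the target object and decides whether to wrap it. The usual implementation, Plugin.wrap(target, this), returns a proxy only when the target implements one of the interfaces named in @Signature, and returns the target unchanged otherwise.

setProperties receives the properties configured for the plugin in the configuration file, so the plugin can read its settings there.

The three methods run in this order: setProperties ---> plugin ---> intercept. setProperties runs once when the configuration is loaded, plugin runs whenever an Executor, StatementHandler, etc. is created, and intercept runs on every intercepted call.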
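MyBatis builds plugins on JDK dynamic proxies, so the lifecycle can be imitated with the standard library alone. The sketch below is not MyBatis code: `SqlHandler`, `DemoInterceptor` and the `comment` property are hypothetical stand-ins, used only to show the order setProperties → plugin → intercept and how the proxy forwards the call after rewriting its argument.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

// Hypothetical stand-in for an intercepted interface such as StatementHandler.
interface SqlHandler {
    String prepare(String sql);
}

class PluginOrderDemo {
    static final List<String> log = new ArrayList<>();

    // Hypothetical interceptor mirroring the three Interceptor methods.
    static class DemoInterceptor {
        private String comment = "";

        // 1. Called once, when the configuration is loaded.
        void setProperties(Properties props) {
            comment = props.getProperty("comment", "");
            log.add("setProperties");
        }

        // 2. Called when the target object is created: wrap it in a JDK proxy,
        //    roughly what Plugin.wrap(target, this) does.
        SqlHandler plugin(SqlHandler target) {
            log.add("plugin");
            InvocationHandler handler = (proxy, method, args) -> {
                // 3. Called on every call made through the proxy.
                log.add("intercept:" + method.getName());
                Object[] newArgs = { args[0] + comment }; // rewrite the SQL argument
                return method.invoke(target, newArgs);   // like invocation.proceed()
            };
            return (SqlHandler) Proxy.newProxyInstance(
                    SqlHandler.class.getClassLoader(),
                    new Class<?>[] { SqlHandler.class },
                    handler);
        }
    }

    static String run() {
        DemoInterceptor interceptor = new DemoInterceptor();
        Properties props = new Properties();
        props.setProperty("comment", " /* intercepted */");
        interceptor.setProperties(props);

        SqlHandler real = sql -> sql;                  // the "real" handler echoes the SQL
        SqlHandler proxied = interceptor.plugin(real); // what callers would receive
        return proxied.prepare("SELECT 1");
    }
}
```

Calling `PluginOrderDemo.run()` returns the rewritten SQL, and `log` records the three phases in order.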

2. Fields of the Invocation class (the parameter of intercept)

 private Object target;	// the intercepted target object (may itself be another plugin's proxy)
 private Method method;	// the intercepted method
 private Object[] args;	// the method's arguments
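In MyBatis, `Invocation.proceed()` simply calls `method.invoke(target, args)`. The stdlib-only imitation below (the class name `MiniInvocation` and the `StringBuilder` target are hypothetical choices for illustration) shows how the three fields work together, and why an interceptor can change the arguments before proceeding.

```java
import java.lang.reflect.Method;

// Hypothetical imitation of org.apache.ibatis.plugin.Invocation.
class MiniInvocation {
    private final Object target;   // the object whose method was intercepted
    private final Method method;   // the intercepted method
    private final Object[] args;   // the call's arguments (the array itself can be modified)

    MiniInvocation(Object target, Method method, Object[] args) {
        this.target = target;
        this.method = method;
        this.args = args;
    }

    Object[] getArgs() { return args; }

    // Run the original method with the (possibly modified) arguments.
    Object proceed() throws ReflectiveOperationException {
        return method.invoke(target, args);
    }

    static String demo() {
        try {
            StringBuilder target = new StringBuilder("SELECT * FROM t");
            Method append = StringBuilder.class.getMethod("append", String.class);
            MiniInvocation inv = new MiniInvocation(target, append, new Object[] { " WHERE 1 = 1" });
            inv.getArgs()[0] = " WHERE t.dept_id = 10"; // an interceptor may rewrite arguments
            return inv.proceed().toString();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }
}
```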


Implementing the Interceptor interface:

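// MyBatis before 3.4.0: args = { Connection.class }; from 3.4.0 on: args = { Connection.class, Integer.class }
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })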
public class MyInterceptor implements Interceptor {

	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		// your logic goes here
		return invocation.proceed();
	}

	@Override
	public Object plugin(Object target) {
		// create the proxy object
		return Plugin.wrap(target, this);
	}

	@Override
	public void setProperties(Properties properties) {
	}

}

Explanation:

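@Intercepts declares the class as an interceptor, and @Signature declares which method it intercepts.

MyBatis creates the JDBC Statement inside StatementHandler: its prepare method returns the Statement, and the SQL it uses can still be changed before that happens. So StatementHandler.prepare is the method to intercept.

method is the name of the intercepted method (prepare), type is the interface being intercepted, and args lists the parameter types of prepare. They must match the method exactly; otherwise MyBatis throws a PluginException ("Could not find method ...") when it tries to wrap the target.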
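MyBatis resolves each `@Signature` with a plain reflective lookup, `type.getMethod(method, args)`. The sketch below reproduces that lookup; `Handler34` is a hypothetical interface shaped like `StatementHandler.prepare` in MyBatis 3.4+, which takes a `Connection` and an `Integer` timeout. It shows that `args = { Connection.class }` does not resolve against that shape.

```java
import java.sql.Connection;

// Hypothetical interface shaped like StatementHandler in MyBatis 3.4+.
interface Handler34 {
    Object prepare(Connection connection, Integer transactionTimeout);
}

class SignatureCheck {
    // True if (type, methodName, args) names an existing public method,
    // i.e. the lookup MyBatis performs for every @Signature.
    static boolean resolves(Class<?> type, String methodName, Class<?>... args) {
        try {
            type.getMethod(methodName, args);
            return true;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }
}
```

This is why the `args` array has to follow the MyBatis version in use.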




Source code walkthrough:

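StatementHandler source (as of MyBatis 3.3.x; since 3.4.0, prepare takes an additional Integer transactionTimeout parameter):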

public interface StatementHandler {

  Statement prepare(Connection connection)
      throws SQLException;

  void parameterize(Statement statement)
      throws SQLException;

  void batch(Statement statement)
      throws SQLException;

  int update(Statement statement)
      throws SQLException;

  <E> List<E> query(Statement statement, ResultHandler resultHandler)
      throws SQLException;

  BoundSql getBoundSql();

  ParameterHandler getParameterHandler();

}

The prepare method declared here is the one we intercept. Its actual implementation is in BaseStatementHandler:

@Override
  public Statement prepare(Connection connection) throws SQLException {
    ErrorContext.instance().sql(boundSql.getSql());
    Statement statement = null;
    try {
      statement = instantiateStatement(connection); // <----- this is the method that creates the Statement
      setStatementTimeout(statement);
      setFetchSize(statement);
      return statement;
    } catch (SQLException e) {
      closeStatement(statement);
      throw e;
    } catch (Exception e) {
      closeStatement(statement);
      throw new ExecutorException("Error preparing statement.  Cause: " + e, e);
    }
  }
protected abstract Statement instantiateStatement(Connection connection) throws SQLException;
instantiateStatement is abstract. Since our SQL is precompiled, the implementation that matters is the one in PreparedStatementHandler:

 @Override
  protected Statement instantiateStatement(Connection connection) throws SQLException {
    String sql = boundSql.getSql(); // <---- this is our SQL statement
    if (mappedStatement.getKeyGenerator() instanceof Jdbc3KeyGenerator) {
      String[] keyColumnNames = mappedStatement.getKeyColumns();
      if (keyColumnNames == null) {
        return connection.prepareStatement(sql, PreparedStatement.RETURN_GENERATED_KEYS);
      } else {
        return connection.prepareStatement(sql, keyColumnNames);
      }
    } else if (mappedStatement.getResultSetType() != null) {
      return connection.prepareStatement(sql, mappedStatement.getResultSetType().getValue(), ResultSet.CONCUR_READ_ONLY);
    } else {
      return connection.prepareStatement(sql);
    }
  }

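Now that the execution flow is clear, we can work on the intercepted StatementHandler. Because instantiateStatement reads the SQL from boundSql, replacing it before invocation.proceed() changes the SQL that is actually executed.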

@Override
	public Object intercept(Invocation invocation) throws Throwable {
		StatementHandler handler = (StatementHandler)invocation.getTarget();
		// mappedStatement holds the statement id we need, but it is a protected field of the
		// delegate handler, so read it via reflection (MetaObject)
		MetaObject statementHandler = SystemMetaObject.forObject(handler);
		MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement");
		// get the SQL
		BoundSql boundSql = handler.getBoundSql();
		String sql = boundSql.getSql();
		// get the statement id
		String id = mappedStatement.getId();
		if ("id.of.the.statement.to.enhance".equals(id)) { // placeholder for the real statement id
			// SQL enhancement goes here
		}
		return invocation.proceed();
	}
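`SystemMetaObject.forObject(handler).getValue("delegate.mappedStatement")` walks a dotted property path and falls back to fields when there is no getter. What that amounts to for this path can be sketched with plain reflection; `readPath`, `Base`, `Prepared` and `Routing` are hypothetical names shaped after BaseStatementHandler, PreparedStatementHandler and RoutingStatementHandler, and the real MetaObject also handles getters, maps, collections and caching.

```java
import java.lang.reflect.Field;

class FieldPath {
    // Find a field declared on the class or any of its superclasses.
    static Field findField(Class<?> type, String name) throws NoSuchFieldException {
        for (Class<?> c = type; c != null; c = c.getSuperclass()) {
            try {
                return c.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {
                // not declared here, try the superclass
            }
        }
        throw new NoSuchFieldException(name);
    }

    // Follow a dotted path such as "delegate.mappedStatement", reading
    // private or protected fields directly (no getters involved).
    static Object readPath(Object root, String path) {
        Object current = root;
        try {
            for (String name : path.split("\\.")) {
                if (current == null) {
                    return null;
                }
                Field field = findField(current.getClass(), name);
                field.setAccessible(true);
                current = field.get(current);
            }
            return current;
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException("Cannot read " + path, e);
        }
    }

    // Hypothetical classes: a protected field in a superclass, reached via a private delegate.
    static class Base { protected Object mappedStatement = "ms-1"; }
    static class Prepared extends Base { }
    static class Routing { private final Object delegate = new Prepared(); }
}
```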


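Once this is done, don't forget to register the interceptor in the MyBatis configuration file:

<plugins>
	<plugin interceptor="com.test.interceptor.MyInterceptor"></plugin>
</plugins>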
That completes writing and configuring the MyBatis interceptor. Next comes SQL parsing with JSqlParser.

II. JSqlParser

GitHub

1. Add the jsqlparser dependency to the project (Maven):


<dependency>
	<groupId>com.github.jsqlparser</groupId>
	<artifactId>jsqlparser</artifactId>
	<version>1.0</version>
</dependency>

2. Parse the SQL

First determine the statement type (SELECT, UPDATE, INSERT, DELETE, ...), then parse the SQL into the object that matches that type:

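CCJSqlParserManager parserManager = new CCJSqlParserManager();
// sqlCommandType is an SqlCommandType enum, so compare it with the enum constant, not a String
if (SqlCommandType.SELECT == sqlCommandType) {
	Select select = (Select) parserManager.parse(new StringReader(sql));
}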
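`MappedStatement.getSqlCommandType()` returns the enum `SqlCommandType`, not a String, so a check written as `"SELECT".equals(sqlCommandType)` compiles but is always false and the enhancement silently never runs. The sketch uses a local stand-in enum `CommandType` (hypothetical; it only mirrors the constant names of `SqlCommandType`) to show the difference.

```java
// Hypothetical stand-in for org.apache.ibatis.mapping.SqlCommandType.
enum CommandType { UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH }

class CommandTypeCheck {
    // Wrong: String.equals(Object) is never true for an enum constant.
    static boolean isSelectByString(CommandType type) {
        return "SELECT".equals(type);
    }

    // Right: enum constants are singletons, so compare them directly.
    static boolean isSelect(CommandType type) {
        return type == CommandType.SELECT;
    }
}
```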
 
  

 
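3. Visit the visitor implementations (SelectVisitorImpl is our own implementation of SelectVisitor)

The idea is to break the SQL statement into many small parts and let a visitor implementation handle each one: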

select.getSelectBody().accept(new SelectVisitorImpl());
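To see the visitor idea without JSqlParser, here is a self-contained toy. The AST types `Node`, `Sel`, `Tbl`, the visitor `PermissionAdder` and the rule `dept_id = 10` are all hypothetical; JSqlParser's real node types are far richer. Each node's `accept` calls back into the visitor, a FROM item is visited with a fresh visitor whose result is ANDed into the enclosing WHERE, and a subquery in FROM is handled recursively, mirroring the roles of SelectVisitorImpl and FromItemVisitorImpl.

```java
// Hypothetical toy AST; each node hands itself back to the visitor.
interface Node { void accept(NodeVisitor v); }

interface NodeVisitor {
    void visit(Sel select);
    void visit(Tbl table);
}

// A table reference in FROM.
class Tbl implements Node {
    final String name;
    Tbl(String name) { this.name = name; }
    public void accept(NodeVisitor v) { v.visit(this); }
}

// A select whose FROM item is a table or a nested select (subquery).
class Sel implements Node {
    final Node from;
    String where; // null when there is no WHERE clause
    Sel(Node from, String where) { this.from = from; this.where = where; }
    public void accept(NodeVisitor v) { v.visit(this); }

    @Override public String toString() {
        String f = from instanceof Tbl ? ((Tbl) from).name : "(" + from + ")";
        return "SELECT * FROM " + f + (where == null ? "" : " WHERE " + where);
    }
}

class PermissionAdder implements NodeVisitor {
    String condition; // set when this visitor is applied to a table

    public void visit(Sel select) {
        PermissionAdder fromVisitor = new PermissionAdder();
        select.from.accept(fromVisitor); // a subquery gets its own condition recursively
        if (fromVisitor.condition != null) {
            select.where = select.where == null
                    ? fromVisitor.condition
                    : "(" + fromVisitor.condition + ") AND (" + select.where + ")";
        }
    }

    public void visit(Tbl table) {
        condition = table.name + ".dept_id = 10"; // hypothetical permission rule
    }
}
```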

SelectVisitorImpl.java:

package com.test.sqlparser.visitor;

import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitor;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.select.WithItem;

public class SelectVisitorImpl implements SelectVisitor {
	// Main job: implement the lower-level visitors and add the conditions while traversing

	// a plain SELECT (not a set operation such as UNION, and not a WITH item)
	@Override
	public void visit(PlainSelect plainSelect) {

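		// visit the select items
		if (plainSelect.getSelectItems() != null) {
			for (SelectItem item : plainSelect.getSelectItems()) {
				item.accept(new SelectItemVisitorImpl());
			}
		}

		// visit WHERE, so subqueries inside the original condition are enhanced too
		if (plainSelect.getWhere() != null) {
			plainSelect.getWhere().accept(new ExpressionVisitorImpl());
		}

		// visit FROM and collect the permission condition of the main table
		Expression enhancedCondition = null;
		FromItem fromItem = plainSelect.getFromItem();
		if (fromItem != null) { // e.g. "SELECT 1" has no FROM
			FromItemVisitorImpl fromItemVisitorImpl = new FromItemVisitorImpl();
			fromItem.accept(fromItemVisitorImpl);
			enhancedCondition = fromItemVisitorImpl.getEnhancedCondition();
		}

		// visit JOINs and AND the joined tables' conditions in as well
		// (caveat: for LEFT/RIGHT JOIN, a WHERE condition on the outer-joined table
		// effectively turns the join into an inner join)
		if (plainSelect.getJoins() != null) {
			for (Join join : plainSelect.getJoins()) {
				FromItemVisitorImpl joinVisitor = new FromItemVisitorImpl();
				join.getRightItem().accept(joinVisitor);
				if (joinVisitor.getEnhancedCondition() != null) {
					enhancedCondition = enhancedCondition == null
							? joinVisitor.getEnhancedCondition()
							: new AndExpression(enhancedCondition, joinVisitor.getEnhancedCondition());
				}
			}
		}

		// add the permission condition: (enhanced condition) AND (original WHERE)
		if (enhancedCondition != null) {
			if (plainSelect.getWhere() != null) {
				Expression expr = new Parenthesis(plainSelect.getWhere());
				AndExpression and = new AndExpression(new Parenthesis(enhancedCondition), expr);
				plainSelect.setWhere(and);
			} else {
				plainSelect.setWhere(enhancedCondition);
			}
		}

		// visit ORDER BY
		if (plainSelect.getOrderByElements() != null) {
			for (OrderByElement orderByElement : plainSelect.getOrderByElements()) {
				orderByElement.getExpression().accept(new ExpressionVisitorImpl());
			}
		}

		// visit HAVING
		if (plainSelect.getHaving() != null) {
			plainSelect.getHaving().accept(new ExpressionVisitorImpl());
		}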

	}

	// set operations (UNION, INTERSECT, ...)
	@Override
	public void visit(SetOperationList setOpList) {
		for (SelectBody plainSelect : setOpList.getSelects()) {
			plainSelect.accept(new SelectVisitorImpl());
		}
	}

	// WITH item
	@Override
	public void visit(WithItem withItem) {
		withItem.getSelectBody().accept(new SelectVisitorImpl());
	}

}
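The `new Parenthesis(...)` wrappers in visit(PlainSelect) are what keep the filter intact. If the original WHERE is `a = 1 OR b = 2` and the permission condition were simply prefixed with AND, the text would parse as `(dept AND a = 1) OR b = 2`, and rows from other departments matching `b = 2` would leak through. Java's `&&` and `||` have the same relative precedence as SQL's AND and OR, so a tiny stdlib sketch (the class name is hypothetical) shows the effect on one row:

```java
class PrecedenceDemo {
    // "dept AND a OR b" without parentheses parses as "(dept AND a) OR b",
    // because AND binds tighter than OR (in SQL and in Java alike).
    static boolean unparenthesized(boolean dept, boolean a, boolean b) {
        return dept && a || b;
    }

    // The shape built with Parenthesis: (dept) AND (a OR b).
    static boolean parenthesized(boolean dept, boolean a, boolean b) {
        return dept && (a || b);
    }
}
```

For a row outside the user's department (`dept == false`) that matches `b`, the first form returns true and the second returns false.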
SelectItemVisitorImpl.java:

package com.test.sqlparser.visitor;

import net.sf.jsqlparser.statement.select.AllColumns;
import net.sf.jsqlparser.statement.select.AllTableColumns;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItemVisitor;

public class SelectItemVisitorImpl implements SelectItemVisitor {

	@Override
	public void visit(AllColumns allColumns) {
	}
	
	@Override
	public void visit(AllTableColumns allTableColumns) {
	}

	@Override
	public void visit(SelectExpressionItem selectExpressionItem) {
		selectExpressionItem.getExpression().accept(new ExpressionVisitorImpl());
	}

}
ExpressionVisitorImpl.java:

package com.test.sqlparser.visitor;


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.sf.jsqlparser.expression.AllComparisonExpression;
import net.sf.jsqlparser.expression.AnalyticExpression;
import net.sf.jsqlparser.expression.AnyComparisonExpression;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.CaseExpression;
import net.sf.jsqlparser.expression.CastExpression;
import net.sf.jsqlparser.expression.DateTimeLiteralExpression;
import net.sf.jsqlparser.expression.DateValue;
import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitor;
import net.sf.jsqlparser.expression.ExtractExpression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.HexValue;
import net.sf.jsqlparser.expression.IntervalExpression;
import net.sf.jsqlparser.expression.JdbcNamedParameter;
import net.sf.jsqlparser.expression.JdbcParameter;
import net.sf.jsqlparser.expression.JsonExpression;
import net.sf.jsqlparser.expression.KeepExpression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.MySQLGroupConcat;
import net.sf.jsqlparser.expression.NullValue;
import net.sf.jsqlparser.expression.NumericBind;
import net.sf.jsqlparser.expression.OracleHierarchicalExpression;
import net.sf.jsqlparser.expression.OracleHint;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.RowConstructor;
import net.sf.jsqlparser.expression.SignedExpression;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.TimeKeyExpression;
import net.sf.jsqlparser.expression.TimeValue;
import net.sf.jsqlparser.expression.TimestampValue;
import net.sf.jsqlparser.expression.UserVariable;
import net.sf.jsqlparser.expression.WhenClause;
import net.sf.jsqlparser.expression.WithinGroupExpression;
import net.sf.jsqlparser.expression.operators.arithmetic.Addition;
import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseAnd;
import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseOr;
import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseXor;
import net.sf.jsqlparser.expression.operators.arithmetic.Concat;
import net.sf.jsqlparser.expression.operators.arithmetic.Division;
import net.sf.jsqlparser.expression.operators.arithmetic.Modulo;
import net.sf.jsqlparser.expression.operators.arithmetic.Multiplication;
import net.sf.jsqlparser.expression.operators.arithmetic.Subtraction;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.Matches;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.expression.operators.relational.RegExpMatchOperator;
import net.sf.jsqlparser.expression.operators.relational.RegExpMySQLOperator;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.SubSelect;
import net.sf.jsqlparser.statement.select.WithItem;

public class ExpressionVisitorImpl implements ExpressionVisitor {

	Logger logger =LoggerFactory.getLogger(ExpressionVisitorImpl.class);
	
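	// signed expression, e.g. -x: visit the wrapped expression
	// (calling signedExpression.accept(...) here would recurse forever)
	@Override
	public void visit(SignedExpression signedExpression) {
		signedExpression.getExpression().accept(new ExpressionVisitorImpl());
	}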

	// JDBC parameter (?)
	@Override
	public void visit(JdbcParameter jdbcParameter) {
	}

	// named JDBC parameter (:name)
	@Override
	public void visit(JdbcNamedParameter jdbcNamedParameter) {
	}

	// parenthesized expression
	@Override
	public void visit(Parenthesis parenthesis) {
		parenthesis.getExpression().accept(new ExpressionVisitorImpl());
	}

	// between
	@Override
	public void visit(Between between) {
		between.getLeftExpression().accept(new ExpressionVisitorImpl());
		between.getBetweenExpressionStart().accept(new ExpressionVisitorImpl());
		between.getBetweenExpressionEnd().accept(new ExpressionVisitorImpl());
	}

	// IN expression; ItemsListVisitorImpl is a custom ItemsListVisitor implementation (not listed here)
	@Override
	public void visit(InExpression inExpression) {
		if (inExpression.getLeftExpression() != null) {
			inExpression.getLeftExpression()
					.accept(new ExpressionVisitorImpl());
		} else if (inExpression.getLeftItemsList() != null) {
			inExpression.getLeftItemsList().accept(new ItemsListVisitorImpl());
		}
		inExpression.getRightItemsList().accept(new ItemsListVisitorImpl());
	}

	// subquery
	@Override
	public void visit(SubSelect subSelect) {
		if (subSelect.getWithItemsList() != null) {
			for (WithItem withItem : subSelect.getWithItemsList()) {
				withItem.accept(new SelectVisitorImpl());
			}
		}
		subSelect.getSelectBody().accept(new SelectVisitorImpl());
	}

	// EXISTS (subquery)
	@Override
	public void visit(ExistsExpression existsExpression) {
		existsExpression.getRightExpression().accept(
				new ExpressionVisitorImpl());
	}

	// comparison with ALL, e.g. x > ALL (subquery)
	@Override
	public void visit(AllComparisonExpression allComparisonExpression) {
		allComparisonExpression.getSubSelect().getSelectBody()
				.accept(new SelectVisitorImpl());
	}

	// comparison with ANY, e.g. x = ANY (subquery)
	@Override
	public void visit(AnyComparisonExpression anyComparisonExpression) {
		anyComparisonExpression.getSubSelect().getSelectBody()
				.accept(new SelectVisitorImpl());
	}

	// Oracle hierarchical query (START WITH ... CONNECT BY ...)
	@Override
	public void visit(OracleHierarchicalExpression oexpr) {
		if (oexpr.getStartExpression() != null) {
			oexpr.getStartExpression().accept(this);
		}

		if (oexpr.getConnectExpression() != null) {
			oexpr.getConnectExpression().accept(this);
		}
	}

	// row constructor, e.g. (a, b) or ROW(a, b)
	@Override
	public void visit(RowConstructor rowConstructor) {
		for (Expression expr : rowConstructor.getExprList().getExpressions()) {
			expr.accept(this);
		}
	}

	// cast
	@Override
	public void visit(CastExpression cast) {
		cast.getLeftExpression().accept(new ExpressionVisitorImpl());
	}

	// addition
	@Override
	public void visit(Addition addition) {
		visitBinaryExpression(addition);
	}

	// division
	@Override
	public void visit(Division division) {
		visitBinaryExpression(division);
	}

	// multiplication
	@Override
	public void visit(Multiplication multiplication) {
		visitBinaryExpression(multiplication);
	}

	// subtraction
	@Override
	public void visit(Subtraction subtraction) {
		visitBinaryExpression(subtraction);
	}

	// AND expression
	@Override
	public void visit(AndExpression andExpression) {
		visitBinaryExpression(andExpression);
	}

	// OR expression
	@Override
	public void visit(OrExpression orExpression) {
		visitBinaryExpression(orExpression);
	}

	// equals (=)
	@Override
	public void visit(EqualsTo equalsTo) {
		visitBinaryExpression(equalsTo);
	}

	// greater than (>)
	@Override
	public void visit(GreaterThan greaterThan) {
		visitBinaryExpression(greaterThan);
	}

	// greater than or equal (>=)
	@Override
	public void visit(GreaterThanEquals greaterThanEquals) {
		visitBinaryExpression(greaterThanEquals);
	}

	// LIKE expression
	@Override
	public void visit(LikeExpression likeExpression) {
		visitBinaryExpression(likeExpression);
	}

	// less than (<)
	@Override
	public void visit(MinorThan minorThan) {
		visitBinaryExpression(minorThan);
	}

	// less than or equal (<=)
	@Override
	public void visit(MinorThanEquals minorThanEquals) {
		visitBinaryExpression(minorThanEquals);
	}

	// not equal (<>)
	@Override
	public void visit(NotEqualsTo notEqualsTo) {
		visitBinaryExpression(notEqualsTo);
	}

	// concat
	@Override
	public void visit(Concat concat) {
		visitBinaryExpression(concat);
	}

	// match operator (@@)
	@Override
	public void visit(Matches matches) {
		visitBinaryExpression(matches);
	}

	// bitwise AND (&)
	@Override
	public void visit(BitwiseAnd bitwiseAnd) {
		visitBinaryExpression(bitwiseAnd);
	}

	// bitwise OR (|)
	@Override
	public void visit(BitwiseOr bitwiseOr) {
		visitBinaryExpression(bitwiseOr);
	}

	// bitwise XOR (^)
	@Override
	public void visit(BitwiseXor bitwiseXor) {
		visitBinaryExpression(bitwiseXor);
	}

	// modulo (%)
	@Override
	public void visit(Modulo modulo) {
		visitBinaryExpression(modulo);
	}

	// regular-expression match (~, ~*, !~, !~*)
	@Override
	public void visit(RegExpMatchOperator rexpr) {
		visitBinaryExpression(rexpr);
	}

	// MySQL REGEXP / RLIKE
	@Override
	public void visit(RegExpMySQLOperator regExpMySQLOperator) {
		visitBinaryExpression(regExpMySQLOperator);
	}

	// binary expression: visit both operands
	public void visitBinaryExpression(BinaryExpression binaryExpression) {
		binaryExpression.getLeftExpression()
				.accept(new ExpressionVisitorImpl());
		binaryExpression.getRightExpression().accept(
				new ExpressionVisitorImpl());
	}

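	// ------------------ Everything below is a no-op. Subqueries nested inside CASE, ------------------
	// ------------------ function arguments, analytic functions, IS NULL, etc. are therefore not rewritten ------------------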

	// analytic (window) function, e.g. ROW_NUMBER() OVER (...)
	@Override
	public void visit(AnalyticExpression aexpr) {
	}

	// WITHIN GROUP (ORDER BY ...) expression
	@Override
	public void visit(WithinGroupExpression wgexpr) {
	}

	// EXTRACT(field FROM ...)
	@Override
	public void visit(ExtractExpression eexpr) {
	}

	// INTERVAL expression
	@Override
	public void visit(IntervalExpression iexpr) {
	}

	// JSON expression (PostgreSQL -> / ->>)
	@Override
	public void visit(JsonExpression jsonExpr) {
	}

	// Oracle optimizer hint (/*+ ... */)
	@Override
	public void visit(OracleHint hint) {
	}

	// time keyword such as CURRENT_TIMESTAMP
	@Override
	public void visit(TimeKeyExpression timeKeyExpression) {
	}

	// CASE expression
	@Override
	public void visit(CaseExpression caseExpression) {
	}

	// WHEN clause of a CASE expression
	@Override
	public void visit(WhenClause whenClause) {
	}

	// user variable, e.g. @var
	@Override
	public void visit(UserVariable var) {
	}

	// numeric bind variable, e.g. :1
	@Override
	public void visit(NumericBind bind) {
	}

	// Oracle KEEP (DENSE_RANK FIRST/LAST ORDER BY ...)
	@Override
	public void visit(KeepExpression aexpr) {
	}

	// MySQL GROUP_CONCAT
	@Override
	public void visit(MySQLGroupConcat groupConcat) {
	}

	// column reference
	@Override
	public void visit(Column tableColumn) {
	}

	// DOUBLE literal
	@Override
	public void visit(DoubleValue doubleValue) {
	}

	// LONG literal
	@Override
	public void visit(LongValue longValue) {
	}

	// hexadecimal literal
	@Override
	public void visit(HexValue hexValue) {
	}

	// DATE literal
	@Override
	public void visit(DateValue dateValue) {
	}

	// TIME literal
	@Override
	public void visit(TimeValue timeValue) {
	}

	// TIMESTAMP literal
	@Override
	public void visit(TimestampValue timestampValue) {
	}

	// NULL
	@Override
	public void visit(NullValue nullValue) {
	}

	// function call
	@Override
	public void visit(Function function) {
	}

	// string literal
	@Override
	public void visit(StringValue stringValue) {
	}

	// IS NULL expression
	@Override
	public void visit(IsNullExpression isNullExpression) {
	}

	// date/time literal such as DATE '2020-01-01'
	@Override
	public void visit(DateTimeLiteralExpression literal) {
	}
}
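A visit method has to descend into the node's children instead of calling accept on the node itself: `signedExpression.accept(visitor)` just calls `visitor.visit(signedExpression)` again, so that pattern never terminates. The stdlib toy below (hypothetical types `Expr`, `Num`, `Signed` and two visitors) reproduces both shapes: the looping visitor overflows the stack, while the one that visits the wrapped expression returns a node count.

```java
// Hypothetical toy expression tree with double dispatch.
interface Expr { int accept(ExprVisitor v); }
interface ExprVisitor { int visit(Signed s); int visit(Num n); }

class Num implements Expr {
    final int value;
    Num(int value) { this.value = value; }
    public int accept(ExprVisitor v) { return v.visit(this); }
}

// Wraps another expression, like SignedExpression wraps its operand.
class Signed implements Expr {
    final Expr inner;
    Signed(Expr inner) { this.inner = inner; }
    public int accept(ExprVisitor v) { return v.visit(this); }
}

// Buggy: re-dispatches the same node, so accept -> visit -> accept ... never ends.
class LoopingVisitor implements ExprVisitor {
    public int visit(Signed s) { return s.accept(this); }
    public int visit(Num n) { return 1; }
}

// Correct: descends into the wrapped expression and counts visited nodes.
class CountingVisitor implements ExprVisitor {
    public int visit(Signed s) { return 1 + s.inner.accept(this); }
    public int visit(Num n) { return 1; }
}
```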
FromItemVisitorImpl.java:

package com.test.sqlparser.visitor;

import java.util.ArrayList;
import java.util.List;

import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.FromItemVisitor;
import net.sf.jsqlparser.statement.select.LateralSubSelect;
import net.sf.jsqlparser.statement.select.SubJoin;
import net.sf.jsqlparser.statement.select.SubSelect;
import net.sf.jsqlparser.statement.select.TableFunction;
import net.sf.jsqlparser.statement.select.ValuesList;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.test.entity.TableCondition;
import com.test.security.UserUtils;

public class FromItemVisitorImpl implements FromItemVisitor {

	private static Logger logger = LoggerFactory
			.getLogger(FromItemVisitorImpl.class);
	// the permission condition built for this FROM item
	private Expression enhancedCondition;

	// FROM table <---- the key part: apply the current user's restrictions on this table
	@Override
	public void visit(Table tableName) {
		// check whether this table should be filtered at all
		if (isActionTable(tableName.getFullyQualifiedName())) {
			// the current user's conditions for this table
			List<TableCondition> tableConditions = UserUtils.getTableCondition(tableName.getFullyQualifiedName().toUpperCase());
			// only if the table has conditions
			if (tableConditions != null) {
				for (TableCondition tableCondition : tableConditions) {
					String op = tableCondition.getOperator();
					// qualify the column with the alias if there is one, otherwise with the table name
					Column column = new Column(new Table(tableName.getAlias() != null ? tableName.getAlias().getName() : tableName.getFullyQualifiedName()), tableCondition.getFieldName());
					Expression[] expressions;
					if ("between".equalsIgnoreCase(op) || "not between".equalsIgnoreCase(op)) {
						// BETWEEN needs two values, which a single fieldValue does not provide;
						// refuse instead of silently dropping the restriction
						throw new UnsupportedOperationException("BETWEEN conditions are not supported yet");
					} else if ("is null".equalsIgnoreCase(op) || "is not null".equalsIgnoreCase(op)) {
						// IS [NOT] NULL only needs the column
						expressions = new Expression[] { column };
					} else if ("1".equals(tableCondition.getFieldName())) {
						// constant condition such as 1 = 1
						expressions = new Expression[] { new LongValue(tableCondition.getFieldName()), new LongValue(tableCondition.getFieldValue()) };
					} else {
						// the common case: column op 'value'
						// (in older JSqlParser versions, new StringValue(s) may strip the first and last
						// character of s, expecting an already quoted literal; check your version)
						expressions = new Expression[] { column, new StringValue(tableCondition.getFieldValue()) };
					}
					// build the expression for the operator and AND it with the previous ones
					Expression condition = this.getOperator(op, expressions);
					if (this.enhancedCondition != null) {
						enhancedCondition = new AndExpression(enhancedCondition, condition);
					} else {
						enhancedCondition = condition;
					}
				}
			}
		}
	}

	// FROM (subquery)
	@Override
	public void visit(SubSelect subSelect) {
		// a subquery is handled by the SelectVisitor implementation again
		subSelect.getSelectBody().accept(new SelectVisitorImpl());
	}

	// FROM subjoin; note that conditions built for the tables inside it are not applied here
	@Override
	public void visit(SubJoin subjoin) {
		subjoin.getLeft().accept(new FromItemVisitorImpl());
		subjoin.getJoin().getRightItem().accept(new FromItemVisitorImpl());
	}

	// FROM LATERAL subquery
	@Override
	public void visit(LateralSubSelect lateralSubSelect) {
		lateralSubSelect.getSubSelect().getSelectBody()
				.accept(new SelectVisitorImpl());
	}

	// FROM VALUES list
	@Override
	public void visit(ValuesList valuesList) {
	}

	// FROM tableFunction
	@Override
	public void visit(TableFunction tableFunction) {
	}

	// convert the operator string into the corresponding JSqlParser expression
	private Expression getOperator(String op, Expression[] exp) {
		if ("=".equals(op)) {
			EqualsTo eq = new EqualsTo();
			eq.setLeftExpression(exp[0]);
			eq.setRightExpression(exp[1]);
			return eq;
		} else if (">".equals(op)) {
			GreaterThan gt = new GreaterThan();
			gt.setLeftExpression(exp[0]);
			gt.setRightExpression(exp[1]);
			return gt;
		} else if (">=".equals(op)) {
			GreaterThanEquals geq = new GreaterThanEquals();
			geq.setLeftExpression(exp[0]);
			geq.setRightExpression(exp[1]);
			return geq;
		} else if ("<".equals(op)) {
			MinorThan mt = new MinorThan();
			mt.setLeftExpression(exp[0]);
			mt.setRightExpression(exp[1]);
			return mt;
		} else if ("<=".equals(op)) {
			MinorThanEquals leq = new MinorThanEquals();
			leq.setLeftExpression(exp[0]);
			leq.setRightExpression(exp[1]);
			return leq;
		} else if ("<>".equals(op)) {
			NotEqualsTo neq = new NotEqualsTo();
			neq.setLeftExpression(exp[0]);
			neq.setRightExpression(exp[1]);
			return neq;
		} else if ("is null".equalsIgnoreCase(op)) {
			IsNullExpression isNull = new IsNullExpression();
			isNull.setNot(false);
			isNull.setLeftExpression(exp[0]);
			return isNull;
		} else if ("is not null".equalsIgnoreCase(op)) {
			IsNullExpression isNull = new IsNullExpression();
			isNull.setNot(true);
			isNull.setLeftExpression(exp[0]);
			return isNull;
		} else if ("like".equalsIgnoreCase(op)) {
			LikeExpression like = new LikeExpression();
			like.setNot(false);
			like.setLeftExpression(exp[0]);
			like.setRightExpression(exp[1]);
			return like;
		} else if ("not like".equalsIgnoreCase(op)) {
			LikeExpression nlike = new LikeExpression();
			nlike.setNot(true);
			nlike.setLeftExpression(exp[0]);
			nlike.setRightExpression(exp[1]);
			return nlike;
		} else if ("between".equalsIgnoreCase(op)) {
			Between bt = new Between();
			bt.setNot(false);
			bt.setLeftExpression(exp[0]);
			bt.setBetweenExpressionStart(exp[1]);
			bt.setBetweenExpressionEnd(exp[2]);
			return bt;
		} else if ("not between".equalsIgnoreCase(op)) {
			Between bt = new Between();
			bt.setNot(true);
			bt.setLeftExpression(exp[0]);
			bt.setBetweenExpressionStart(exp[1]);
			bt.setBetweenExpressionEnd(exp[2]);
			return bt;
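		} else {
			// unknown operator: fail instead of returning null, which would silently drop the restriction
			throw new IllegalArgumentException("Unsupported operator: " + op);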
		}

	}

	public Expression getEnhancedCondition() {
		return enhancedCondition;
	}

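	// check whether the given table should be filtered
	public boolean isActionTable(String tableName) {
		// filter by default
		boolean flag = true;
		// names (in upper case) of tables that need no filtering
		List<String> tableNames = new ArrayList<String>();
		// table names in the SQL may be lower case, so compare in upper case (as in visit(Table) above)
		if (tableNames.contains(tableName.toUpperCase())) {
			// the table is excluded, so set flag to false
			flag = false;
		}
		return flag;
	}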

}
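The same mapping from a configured condition to SQL can be sketched without JSqlParser by building the fragment as text. `Rule` is a hypothetical stand-in for the author's `TableCondition`; only its three fields fieldName, operator and fieldValue are assumed. Two behaviours here are my assumptions rather than the article's: unknown operators throw, so a misconfigured rule denies access instead of widening it, and single quotes in values are doubled. In production, JSqlParser expression objects or bind parameters are preferable to string concatenation.

```java
import java.util.Locale;

class ConditionSql {
    // Hypothetical stand-in for TableCondition: column, operator, value.
    static final class Rule {
        final String fieldName, operator, fieldValue;
        Rule(String fieldName, String operator, String fieldValue) {
            this.fieldName = fieldName;
            this.operator = operator;
            this.fieldValue = fieldValue;
        }
    }

    // Quote a value as an SQL string literal, doubling embedded single quotes.
    static String literal(String value) {
        return "'" + value.replace("'", "''") + "'";
    }

    // Build "alias.column op value"; unsupported operators (including BETWEEN) fail closed.
    static String toSql(String tableAlias, Rule r) {
        String column = tableAlias + "." + r.fieldName;
        String op = r.operator.trim().toLowerCase(Locale.ROOT);
        switch (op) {
            case "=": case ">": case ">=": case "<": case "<=": case "<>":
                return column + " " + op + " " + literal(r.fieldValue);
            case "like":
            case "not like":
                return column + " " + op.toUpperCase(Locale.ROOT) + " " + literal(r.fieldValue);
            case "is null":
            case "is not null":
                return column + " " + op.toUpperCase(Locale.ROOT);
            default:
                throw new IllegalArgumentException("Unsupported operator: " + r.operator);
        }
    }
}
```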


The complete interceptor:

package com.test.interceptor;

import java.io.StringReader;
import java.sql.Connection;

import java.util.Properties;

import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.Select;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import com.test.sqlparser.visitor.SelectVisitorImpl;


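// MyBatis before 3.4.0: args = { Connection.class }; from 3.4.0 on: args = { Connection.class, Integer.class }
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })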
public class MyInterceptor implements Interceptor {
	CCJSqlParserManager parserManager = new CCJSqlParserManager();
	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		StatementHandler handler = (StatementHandler)invocation.getTarget();
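		// mappedStatement is a protected field of the delegate, so read it via reflection (MetaObject);
		// the "delegate" path assumes the target is RoutingStatementHandler, not another plugin's proxy
		MetaObject statementHandler = SystemMetaObject.forObject(handler);
		// mappedStatement holds the statement id we need
		MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement");
		// get the SQL
		BoundSql boundSql = handler.getBoundSql();
		String sql = boundSql.getSql();
		// get the statement id (namespace + "." + method name)
		String id = mappedStatement.getId();
		// get the command type; this is an SqlCommandType enum, not a String
		SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
		if ("id.of.the.statement.to.enhance".equals(id)) { // placeholder for the real statement id
			// SQL enhancement
			if (SqlCommandType.SELECT == sqlCommandType) {
				// parse the SELECT into a Select object
				Select select = (Select) parserManager.parse(new StringReader(sql));
				// walk it with the visitors, which add the permission conditions
				select.getSelectBody().accept(new SelectVisitorImpl());
				// put the rewritten SQL back
				statementHandler.setValue("delegate.boundSql.sql", select.toString());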
			}
		}
		return invocation.proceed();
	}

	@Override
	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	@Override
	public void setProperties(Properties properties) {
	}

}
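`statementHandler.setValue("delegate.boundSql.sql", ...)` works even though `BoundSql.sql` is a `private final` field, because MetaObject falls back to reflective field access when there is no setter. A stdlib sketch of that fallback follows; `FakeBoundSql` and `setField` are hypothetical names. Overwriting a final instance field like this works for ordinary classes (assigned in the constructor, not a compile-time constant), but it remains a reflection hack that depends on MyBatis internals.

```java
import java.lang.reflect.Field;

class FinalFieldWrite {
    // Hypothetical stand-in for org.apache.ibatis.mapping.BoundSql.
    static final class FakeBoundSql {
        private final String sql;
        FakeBoundSql(String sql) { this.sql = sql; } // assigned in the constructor
        String getSql() { return sql; }
    }

    // Overwrite a private (even final) instance field via reflection,
    // roughly what MetaObject does when no setter exists.
    static void setField(Object target, String name, Object value) {
        try {
            Field f = target.getClass().getDeclaredField(name);
            f.setAccessible(true);
            f.set(target, value);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot set " + name, e);
        }
    }
}
```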





















