老年人教程:MyBatis拦截器动态修改SQL(更新与插入)语句

注:本文编写与 2019年12月17日, 内容可能存在时效性问题。
数据库使用MySQL5.7 集成于SpringBoot 2.0.X , 引用国产的开源工具类Hutool

本教程建议显示大纲视图 配合食用

一 、简介

1. 设定使用场景

任意 insert 、update 语句,都需要记录下语句的操作用户(CREATOR), 但由于系统早期设计不规范,大量SQL语句中,部分有写入CREATOR部分没有CREATOR,现在需要设计一段代码,能够自动拦截下所有的insert和update 语句,自动添加user 字段,及其字段值;

例如插入语句:

insert into table(id,name,age) values("efdefbb1970b486f8985fa19ab3ab22e","筱黄舒",12);
# 拦截后修改成
insert into table(id,name,age,createor) values("efdefbb1970b486f8985fa19ab3ab22e","筱黄舒",12,"admin");

2. 教程所用到的框架

此处列出 pom.xml 所依赖的一些包的版本,其中用到hutool这个工具类

<properties>
  <java.version>1.8</java.version>
  <hutool-all.version>4.5.1</hutool-all.version>
  <spring-boot-start.version>2.0.7.RELEASE</spring-boot-start.version>
  <druid-spring-boot-starter.version>1.1.10</druid-spring-boot-starter.version>
  <mysql-connector-java.version>5.1.47</mysql-connector-java.version>
  <spring-boot-start-mybatis.version>2.0.0</spring-boot-start-mybatis.version>
  .......
</properties>

二、拦截器

2.1 创建拦截器

@Intercepts({@Signature(
      type = Executor.class, method = "update",
      args = {MappedStatement.class, Object.class})})
@Component
public class demo implements Interceptor {
   private static final int PARAMETER_INDEX =  1;

private static final int MAPPED_STATEMENT_INDEX  = NumConstant.COMMON_NUMBER_ZERO;

private static final String UPDATE_SQL_TYPE = "Update";
private static final String INSERT_SQL_TYPE = "Insert";

private Logger logger  = LoggerFactory.getLogger(ModifyInterceptor.class);

private static LRUCache<String, Boolean> modifyCache = new LRUCache<>(Byte.MAX_VALUE);
private static LRUCache<String, Boolean> insertCache = new LRUCache<>(Byte.MAX_VALUE);
   
   /**
   * 拦截主要的逻辑
   */   
   @Override
   public Object intercept(Invocation invocation) throws Throwable {
      return null;
   }


   @Override
   public Object plugin(Object o) {
      return o instanceof Executor? Plugin.wrap(o, this):o;
   }


   @Override
   public void setProperties(Properties properties) {
     // not thing todo
   }
   
   //这个东西主要是标注在不需要被拦截器处理的SQL上,flag=false的话会跳过
   @Target({ElementType.METHOD})
   @Retention(RetentionPolicy.RUNTIME)
   public @interface AutoModify {
      boolean flag() default true;
   }

/**
 * 判断字段是否存在
 */
@Mapper
@Component
public interface IModifyDAO {
   @Select("select count(*) from information_schema.columns where  table_name = #{tableName} and table_schema  = (select database()) and column_name in ('MODIFIER')")
   int existsModifyColumn( @Param("tableName") String tableName);

   @Select("select count(*) from information_schema.columns where  table_name = #{tableName} and table_schema  = (select database()) and column_name in ('CREATOR')")
   int existsInsertColumn(String tableName);
}
 
   
}

拦截器主要要实现的方法是 intercept ,

另外在上述类中 , 有两个cache , 是用来存放表是否存在字段的缓存, 有些表不一定存在有user字段,所以在修改SQL前需要判断字段是否存在,但是没必要每次都向数据库确认,因此加上个缓存。

  1. intercept 方法的实现

由于代码太长 ,不把所有细节贴出来,下面会按方法步骤讲解,需要各位看官根据自己业务需要去整合

2.2 获取注解

获取注解主要是判断当且方法所执行的SQL是否需要跳过

private AutoModify getAnnotation(Invocation invocation) throws ClassNotFoundException {
   final Object[]        args            = invocation.getArgs();
   final MappedStatement mappedStatement = getMs(args);
   String                namespace       = mappedStatement.getId();
   String                className       = namespace.substring(MAPPED_STATEMENT_INDEX, namespace.lastIndexOf("."));
   String                methedName      = namespace.substring(namespace.lastIndexOf(".") + PARAMETER_INDEX);
   Method[]              ms              = Class.forName(className).getMethods();
   for (Method m : ms) {
      if (m.getName().equals(methedName)) {
         return m.getAnnotation(AutoModify.class);
      }
   }
   return null;
}

判断是否存在注解

AutoModify annotation = this.getAnnotation(invocation);
if (annotation == null || annotation.flag()) {
  // 拦截器逻辑  
}

使用方法,在dao层添加注解 @AutoModify(flag=false), 这样拦截器会跳过该DAO层方法

@Mapper
public interface UserDao {
  @AutoModify(flag = false)
  User get(String id);
}

2.3 通过BoundSql 获取SQL语句

在实现

public Object intercept(Invocation invocation) throws Throwable{} 

Mybatis 会传来一个Invocation ,我们能够从这个参数中获取到很多Mybatis 执行过程中相关的变量值,SQL语句也在其中
获取SQL语句

private String getBound(Invocation invocation) {
   Object[]        args          = invocation.getArgs();
   MappedStatement ms            = getMs(args);
   Configuration   configuration = ms.getConfiguration();
   Object          parameter     = args[PARAMETER_INDEX];
   Object          target        = invocation.getTarget();
   StatementHandler handler =   configuration.newStatementHandler((Executor) target, ms,
         parameter, RowBounds.DEFAULT, null, null);
   return handler.getBoundSql().getSql();
}

2.4 解析SQL语句 通过Druid 获取到SQL的 Visitor

druid 能够帮我们解析SQL语句,通过visitor 能够访问SQL语句的不同组成部分,例如表名,字段名等

private MySqlSchemaStatVisitor getMySqlSchemaStatVisitor(String sqlStr) {
   MySqlStatementParser   parser       = new MySqlStatementParser(sqlStr);
   SQLStatement           sqlStatement = parser.parseStatementList().get(0);
   MySqlSchemaStatVisitor visitor      = new MySqlSchemaStatVisitor();
   sqlStatement.accept(visitor);
   return visitor;
}

visitor 的使用示例

//获取SQL语句中涉及的所有表 
Map<TableStat.Name, TableStat> tableMap = visitor.getTables(); 
//是否包含某个字段, 如果SQL语句中已经包含我们要添加进去的字段就没必要去修改这个SQL语句了
visitor.containsColumn(table,"modifier")

2.5 改造更新(UPDATE)语句

通过visitor 我们可以得到SQL语句的表结构信息, 根据这些信息我们可以重新构造这个SQL

for (TableStat.Name tableName : tableMap.keySet()) {
   TableStat stat = tableMap.get(tableName);
   String table = tableName.getName();
  
   if (UPDATE_SQL_TYPE.equals(stat.toString()) && auditRule(tableName.getName(),UPDATE_SQL_TYPE)) {
         //auditRule方法是审计规则,根据业务需要去判断,我这里是判断表前缀和表里面是否有modifier字段,具体判断逻辑不列出来了
                 
      }   
}

如何改造一个更新语句,这里使用 Druid 提供的 SQLBuilderFactory , 能够方便的对更新语句进行修改

//sqlStr 是SQL语句, dbType的是数据库类型
SQLUpdateBuilder sqlBuilder = SQLBuilderFactory.createUpdateBuilder(sqlStr, visitor.getDbType());

得到得sqlBuilder 有下面这些接口,基本上就是添加一些where条件,或者添加一些字段什么的
老年人教程:MyBatis拦截器动态修改SQL(更新与插入)语句_第1张图片

下面我们为 SQL语句 添加一个字段( UserUtil.getCurrentUser() 是一个根据当前线程上下文获取用户名的接口,与本次主题关系不大,不列出来了)

SQLUpdateBuilder sqlBuilder = SQLBuilderFactory.createUpdateBuilder(sqlStr, visitor.getDbType());
if(!visitor.containsColumn(table,"modifier")){
   sqlBuilder.set("modifier = '" + UserUtil.getCurrentUser()+"'");
}

然后 sqlBuilder.toString() 看看吧, 看是否多了一个字段

2.6 改造插入(INSERT)语句

INSERT语句比较特殊, 我找了druid 的文档没发现有比较好的接口能够处理这种场景,无奈下查找了下Mybatis 的文档,发现Mybatis 自带一些SQL语句生成的方法,具体思路也是通过visitor获取SQL语句的组成部分,让后通过mybatis 重新组装

private static String convertInsertSQL(String sql, MySqlSchemaStatVisitor visitor) {
		MySqlStatementParser parser = new MySqlStatementParser(sql);
		SQLStatement statement = parser.parseStatement();
		MySqlInsertStatement myStatement = (MySqlInsertStatement) statement;
		String tableName = myStatement.getTableName().getSimpleName();
		List<SQLExpr> columns = myStatement.getColumns();
		List<ValuesClause> vcl = myStatement.getValuesList();
		if (columns == null || columns.size() <= 0 || myStatement.getQuery() != null) {
			return sql;
		}
		return new SQL() {{
			INSERT_INTO(tableName);
			for (int i = 0; i < columns.size(); i++) {
				String column = columns.get(i).toString();
				Object value = vcl.get(0).getValues().get(i);
				//如果是子查询需要添加括号
				if (value instanceof SQLQueryExpr) {
					value = "(" + value.toString() + ")";
				}
				VALUES(column, value.toString());
			}
			if (!visitor.containsColumn(tableName, "CREATOR")) {
				VALUES("CREATOR ", "'" + UserUtil.getCurrentUser() + "'");
			}
		}}.toString();
	}

2.7 应用修改后的SQL

为 invocation.getArgs()[0] 赋值一个新的MappedStatement 即可 , 代码是参考其它网友的,但可惜不能在我电脑运行起来,而且OGNL表达式会丢失, 后面经过一些小调整后正常运行。现在已经找不到当初看的那片文章,如果大家知道请告诉我 ,感谢

invocation.getArgs()[0] = newMappedStatement(getMs(invocation.getArgs()),boundSql,sqlStr);
private MappedStatement newMappedStatement(final MappedStatement ms , BoundSql oldBound ,String sqlStr ){
   MappedStatement newStatement =  copyFromMappedStatement(ms, new BoundSqlSqlSource(oldBound));
   MetaObject msObject =  MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),new DefaultReflectorFactory());
   msObject.setValue("sqlSource.boundSql.sql", sqlStr);
   return newStatement;
}
@SuppressWarnings({ "unchecked", "rawtypes" })
private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource)
{
   MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
   builder.resource(ms.getResource());
   builder.fetchSize(ms.getFetchSize());
   builder.statementType(ms.getStatementType());
   builder.keyGenerator(ms.getKeyGenerator());
   builder.timeout(ms.getTimeout());
   builder.parameterMap(ms.getParameterMap());
   List<ResultMap> resultMaps = new ArrayList<>();
   String          id         = "-inline";
   if (ms.getResultMaps() != null && ms.getResultMaps().size()>0) {
      id = ms.getResultMaps().get(0).getId() + "-inline";
   }
   ResultMap resultMap = new ResultMap.Builder(null, id, Long.class, new ArrayList()).build();
   resultMaps.add(resultMap);
   builder.resultMaps(resultMaps);
   if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
      StringBuilder keyProperties = new StringBuilder();
      for (String keyProperty : ms.getKeyProperties()) {
         keyProperties.append(keyProperty).append(",");
      }
      keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
      builder.keyProperty(keyProperties.toString());
   }
   builder.resultSetType(ms.getResultSetType());
   builder.cache(ms.getCache());
   builder.flushCacheRequired(ms.isFlushCacheRequired());
   builder.useCache(ms.isUseCache());
   builder.databaseId(ms.getDatabaseId());
   return builder.build();
}
static class BoundSqlSqlSource implements SqlSource {
   private BoundSql boundSql;

   BoundSqlSqlSource(BoundSql boundSql) {
      this.boundSql = boundSql;
   }

   @Override
   public BoundSql getBoundSql(Object parameterObject) {
      return boundSql;
   }
}

三、最后

最后简单概括一下 intercept 方法的实现

 
   @Override
   public Object intercept(Invocation invocation) throws Throwable {
      // 1.通过注解判断是否需要处理此SQL
      // 2.获取SQL语句
      // 3.通过Druid解析SQL语句,获取visitor
      // 4.根据visitor 判断SQL语句类型
      // 5.根据语句类型改造SQL语句
      // 6.应用修改后的SQL语句
      return invocation.proceed();
  }

你可能感兴趣的:(Web工程)