主要代码如下:
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
 * MyBatis plugin that fills {@code [[pkName]]} placeholders in INSERT statements with generated ids.
 *
 * <p>On {@code StatementHandler.prepare}, if the statement is an INSERT whose SQL contains one or more
 * {@code [[pkName]]} placeholders (all using the same name), every placeholder is replaced by a single
 * new id of the form {@code seed * ID_STEP + ID_OFFSET}. The per-table seed is initialised once, from
 * {@code max(pkName) / ID_STEP} of the target table, and is then incremented in memory.
 *
 * <p>NOTE(review): seeds are cached in a static map inside this JVM only. Other processes inserting
 * into the same tables are not coordinated with this cache. Confirm that is acceptable for the deployment.
 */
@Component
@Intercepts(@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}))
public class PrimaryKeyInterceptor implements Interceptor {

    static final Logger log = LoggerFactory.getLogger(PrimaryKeyInterceptor.class);

    /** Multiplier applied to a seed to build an id (presumably reserves low digits — confirm intent). */
    private static final long ID_STEP = 1000L;

    /** Offset added to {@code seed * ID_STEP} to build an id. */
    private static final long ID_OFFSET = 1L;

    /** Per-table id seeds, keyed by the unquoted, lower-cased table name. */
    private static final ConcurrentHashMap<String, AtomicLong> seeds = new ConcurrentHashMap<>();

    /**
     * Rewrites the bound SQL of INSERT statements that contain {@code [[pkName]]} placeholders,
     * then lets the statement preparation proceed.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Connection connection = (Connection) invocation.getArgs()[0];
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        // MetaObject is MyBatis's reflection helper for reading/writing (nested) properties of
        // JavaBeans, Collections and Maps.
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // MappedStatement represents one mapped SQL statement (e.g. one statement from a mapper XML).
        // NOTE(review): the "delegate." path assumes the target is a RoutingStatementHandler.
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (mappedStatement.getSqlCommandType() == SqlCommandType.INSERT) {
            String sql = boundSql.getSql();
            List<PkMeta> pkMetas = parsePkHolder(sql);
            String tableName = parseTableName(sql);
            if (!pkMetas.isEmpty() && !tableName.isEmpty()) {
                long newId = generatePkValue(connection, tableName, pkMetas);
                metaObject.setValue("boundSql.sql", replacePkValue(sql, pkMetas, newId));
            }
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }

    /**
     * Location and name of one {@code [[pkName]]} placeholder inside a SQL string.
     */
    static class PkMeta {
        /** Index of the first '[' of the opening "[[". */
        final int lo;
        /** Index of the second ']' of the closing "]]". */
        final int hi;
        /** Trimmed text between the brackets. */
        final String name;

        public PkMeta(String sql, int lo, int hi) {
            this.name = sql.substring(lo + 2, hi - 1).trim();
            this.lo = lo;
            this.hi = hi;
        }

        @Override
        public String toString() {
            return "pk{" + "lo=" + lo + ", hi=" + hi + ", name=" + name + "}";
        }
    }

    /**
     * Replaces every {@code [[pkName]]} placeholder in {@code sql} with {@code newId}.
     *
     * @param sql     the original SQL
     * @param pkMetas placeholders in ascending, non-overlapping order (as produced by {@link #parsePkHolder})
     * @param newId   the id to substitute; all placeholders receive the same value
     * @return the rewritten SQL
     */
    private static String replacePkValue(String sql, List<PkMeta> pkMetas, long newId) {
        StringBuilder sb = new StringBuilder(sql.length() + 16);
        int from = 0;
        for (PkMeta m : pkMetas) {
            sb.append(sql, from, m.lo).append(newId);
            from = m.hi + 1;
        }
        return sb.append(sql, from, sql.length()).toString();
    }

    /**
     * Returns the next id for {@code tableName}, initialising the table's seed from the database
     * on first use.
     *
     * @throws SQLException if the initial {@code select max(...)} query fails
     */
    private long generatePkValue(Connection conn, String tableName, List<PkMeta> pkMetas) throws SQLException {
        // Strip identifier quotes so `t_user`, "t_user" and t_user share one seed.
        String key = tableName.replace("`", "").replace("\"", "").toLowerCase(Locale.ROOT);
        AtomicLong seed = seeds.get(key);
        if (seed == null) {
            synchronized (seeds) {
                // Re-read under the lock: another thread may have initialised the seed while we waited.
                seed = seeds.get(key);
                if (seed == null) {
                    seed = new AtomicLong(loadInitialSeed(conn, tableName, pkMetas.get(0).name));
                    seeds.put(key, seed);
                }
            }
        }
        return seed.incrementAndGet() * ID_STEP + ID_OFFSET;
    }

    /**
     * Queries {@code max(pkName)} of the table and converts it to a seed.
     * Returns 0 when there are no rows; JDBC {@code getLong} also maps SQL NULL to 0.
     */
    private static long loadInitialSeed(Connection conn, String tableName, String pkName) throws SQLException {
        // Identifiers come from the mapper's SQL text, not from request parameters.
        String sql = "select max(" + pkName + ") from " + tableName;
        try (Statement stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(sql)) {
            return rs.next() ? rs.getLong(1) / ID_STEP : 0L;
        }
    }

    /**
     * Parses every {@code [[Primary Key Name]]} placeholder in the SQL.
     *
     * <p>Placeholders are found by a plain text scan, so a literal "[[" inside a quoted value would
     * also be treated as one. An unterminated "[[" ends the scan.
     *
     * @param rawSql the SQL to scan; {@code null} is treated as empty
     * @return the placeholders in order of appearance, or an empty list if more than one distinct
     *         (case-insensitive) name is used
     */
    private static List<PkMeta> parsePkHolder(String rawSql) {
        String sql = rawSql == null ? "" : rawSql;
        List<PkMeta> ret = new ArrayList<>();
        Set<String> pkNames = new HashSet<>();
        int from = 0;
        while (true) {
            int open = sql.indexOf("[[", from);
            if (open < 0) {
                break;
            }
            int close = sql.indexOf("]]", open + 2);
            if (close < 0) {
                // Unterminated placeholder: stop scanning instead of looping forever.
                break;
            }
            PkMeta pkm = new PkMeta(sql, open, close + 1);
            ret.add(pkm);
            pkNames.add(pkm.name.toLowerCase(Locale.ROOT));
            from = close + 2;
        }
        if (pkNames.size() > 1) {
            log.error("[!_!]Existed more than one primary keys:{}, SQL:{}", pkNames, rawSql);
            ret.clear(); // Several different primary-key names: give up rewriting.
        }
        return ret;
    }

    /**
     * Extracts the table name from an {@code insert ... into <table>} statement.
     *
     * <p>The name ends at the first whitespace or '(' character, so both {@code insert into t (a)}
     * and {@code insert into t(a)} yield {@code t}.
     *
     * @param rawSql the SQL to scan; {@code null} is treated as empty
     * @return the table name, or "" if none is found
     */
    private static String parseTableName(String rawSql) {
        String sql = rawSql == null ? "" : rawSql;
        String insertKw = "insert";
        String intoKw = "into";
        int len = sql.length();
        for (int i = 0; i < len; i++) {
            if (!isKeywordAt(sql, insertKw, i)) {
                continue;
            }
            for (int j = i + insertKw.length() + 1; j < len; j++) {
                if (isKeywordAt(sql, intoKw, j)) {
                    int lo = j + intoKw.length();
                    while (lo < len && sql.charAt(lo) <= ' ') {
                        lo++;
                    }
                    int hi = lo;
                    while (hi < len && sql.charAt(hi) > ' ' && sql.charAt(hi) != '(') {
                        hi++;
                    }
                    return sql.substring(lo, hi);
                }
            }
        }
        return "";
    }

    /**
     * Whether {@code keyword} occurs case-insensitively at index {@code i}, with whitespace or a
     * string boundary on both sides.
     */
    private static boolean isKeywordAt(String s, String keyword, int i) {
        return occurAt(s, keyword, i) && isBlank(s, i - 1) && isBlank(s, i + keyword.length());
    }

    /** Whether index {@code i} is outside {@code s} or holds a whitespace/control character. */
    private static boolean isBlank(String s, int i) {
        if (i < 0 || i >= s.length()) {
            return true;
        }
        return s.charAt(i) <= ' ';
    }

    /** Whether {@code sub} occurs in {@code s} at index {@code i}, ignoring case. */
    private static boolean occurAt(String s, String sub, int i) {
        int y = 0;
        int len = s.length();
        int subLen = sub.length();
        for (int x = i; x < len && y < subLen; x++, y++) {
            char c1 = Character.toUpperCase(s.charAt(x));
            char c2 = Character.toUpperCase(sub.charAt(y));
            if (c1 != c2) {
                return false;
            }
        }
        return y >= subLen;
    }
}