// Straight to the code.
package xxx;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
import org.springframework.cglib.proxy.Enhancer;
import org.springframework.cglib.proxy.MethodInterceptor;
import org.springframework.cglib.proxy.MethodProxy;
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.LoggerContext;
import ch.qos.logback.classic.selector.ContextSelector;
import cn.hutool.db.sql.SqlFormatter;
/**
 * Logback {@link ContextSelector} that hands out a CGLIB subclass proxy of the default
 * {@link LoggerContext}. The proxy intercepts {@code getLogger} calls and, for loggers whose
 * name starts with {@code "cn.com.xxx"}, returns a JDK dynamic proxy driven by
 * {@link CustomLoggerInterceptor} instead of the raw logger.
 *
 * <p>NOTE(review): {@link #getLoggerContext()} returns the proxied context while
 * {@link #getLoggerContext(String)} and {@link #detachLoggerContext(String)} return the original
 * one. This mirrors the original code; confirm it is intended.
 */
public class CustomDefaultContextSelector implements ContextSelector, MethodInterceptor {

    /** Only loggers whose name starts with this prefix are wrapped. */
    private static final String WRAPPED_LOGGER_PREFIX = "cn.com.xxx";

    /** Logger name -> wrapping proxy, so every caller receives the same proxy per name. */
    private static final ConcurrentHashMap<String, org.slf4j.Logger> cachedLogger =
            new ConcurrentHashMap<>(1000);

    private final LoggerContext defaultLoggerContext;

    /** Lazily created CGLIB proxy; guarded by {@code this}. */
    private LoggerContext proxyedDefaultLoggerContext;

    /**
     * @param context the logger context whose class is subclassed by the proxy
     */
    public CustomDefaultContextSelector(LoggerContext context) {
        this.defaultLoggerContext = context;
    }

    /** Returns the proxied default context (same as {@link #getDefaultLoggerContext()}). */
    @Override
    public LoggerContext getLoggerContext() {
        return getDefaultLoggerContext();
    }

    /**
     * Returns the CGLIB proxy, creating it on first use.
     *
     * <p>Synchronized so that concurrent first calls cannot each create a different proxy
     * instance.
     */
    @Override
    public synchronized LoggerContext getDefaultLoggerContext() {
        if (proxyedDefaultLoggerContext == null) {
            Enhancer enhancer = new Enhancer();
            enhancer.setSuperclass(defaultLoggerContext.getClass());
            enhancer.setCallback(this);
            proxyedDefaultLoggerContext = (LoggerContext) enhancer.create();
        }
        return proxyedDefaultLoggerContext;
    }

    /** Returns the original (un-proxied) context regardless of the given name. */
    @Override
    public LoggerContext detachLoggerContext(String loggerContextName) {
        return defaultLoggerContext;
    }

    /** Returns a single-element list holding the original context's name. */
    @Override
    public List<String> getContextNames() {
        return Arrays.asList(defaultLoggerContext.getName());
    }

    /**
     * @return the original (un-proxied) context when {@code name} equals its name,
     *         otherwise {@code null}
     */
    @Override
    public LoggerContext getLoggerContext(String name) {
        return defaultLoggerContext.getName().equals(name) ? defaultLoggerContext : null;
    }

    /**
     * CGLIB callback for every method of the proxied context. The call is always forwarded to
     * the superclass implementation. Only when the method is named {@code getLogger}, is declared
     * to return {@code org.slf4j.Logger}, and the resulting logger name starts with
     * {@code "cn.com.xxx"} is the result replaced by a cached wrapping proxy.
     */
    @Override
    public Object intercept(Object o, Method method, Object[] args, MethodProxy methodProxy) throws Throwable {
        Object result = methodProxy.invokeSuper(o, args);

        boolean isGetLogger = Objects.equals(method.getName(), "getLogger")
                && Objects.equals(method.getReturnType().getName(), org.slf4j.Logger.class.getName());
        if (!isGetLogger) {
            return result;
        }

        String loggerName = ((org.slf4j.Logger) result).getName();
        if (!loggerName.startsWith(WRAPPED_LOGGER_PREFIX)) {
            return result;
        }

        org.slf4j.Logger cached = cachedLogger.get(loggerName);
        if (cached != null) {
            return cached;
        }
        // putIfAbsent keeps a single proxy per name even when two threads get here together.
        org.slf4j.Logger created = wrapLogger(result);
        org.slf4j.Logger raced = cachedLogger.putIfAbsent(loggerName, created);
        return raced != null ? raced : created;
    }

    /** Builds a JDK proxy over all interfaces of the raw logger's class. */
    private static org.slf4j.Logger wrapLogger(Object rawLogger) {
        CustomLoggerInterceptor interceptor = new CustomLoggerInterceptor();
        interceptor.setLogger((Logger) rawLogger);
        Class<?> rawType = rawLogger.getClass();
        return (org.slf4j.Logger) Proxy.newProxyInstance(
                rawType.getClassLoader(), rawType.getInterfaces(), interceptor);
    }

    /** Exposes the live (mutable) proxy cache. */
    public static ConcurrentHashMap<String, org.slf4j.Logger> getCachedLogger() {
        return cachedLogger;
    }

    /**
     * Invocation handler that forwards every call to the real logger exactly once and, in
     * addition, watches for messages starting with {@code "==> Preparing: "} and
     * {@code "==> Parameters: "}. The former is remembered per thread; when the latter arrives
     * the remembered SQL has its {@code ?} placeholders filled in and is printed to standard
     * output.
     *
     * <p>NOTE(review): the two prefixes and the {@code value(Type), value(Type)} layout look like
     * MyBatis's JDBC statement log lines - confirm against the MyBatis version in use.
     */
    public static class CustomLoggerInterceptor implements InvocationHandler {

        private static final String PARAMETERS = "==> Parameters: ";
        private static final String PREPARING = "==> Preparing: ";

        /** Text used for a parameter that carries no {@code (Type)} suffix because it is null. */
        private static final String NULL_TOKEN = "null";
        private static final String NULL_THEN_SEPARATOR = NULL_TOKEN + ", ";

        /**
         * Index of the argument inspected as the log message.
         * Presumably the {@code message} of {@code LocationAwareLogger.log(marker, fqcn, level,
         * message, argArray, t)} - TODO confirm.
         */
        private static final int MESSAGE_ARG_INDEX = 3;

        /** Level-check methods that are forwarded without any message inspection. */
        private static final Set<String> skipMethodSet = new HashSet<String>(Arrays.asList(
                "isTraceEnabled", "isDebugEnabled", "isInfoEnabled", "isWarnEnabled", "isErrorEnabled"));

        /** Parameter types whose values are rendered inside single quotes. */
        private static final Set<String> QUOTED_TYPES = new HashSet<String>(Arrays.asList(
                "String", "Timestamp", "Date", "Time", "LocalDate", "LocalTime", "LocalDateTime"));

        /** Matches a recognised {@code (Type)} suffix with an optional trailing comma. */
        Pattern pattern = Pattern.compile(
                "\\(String\\),{0,1}|\\(Timestamp\\),{0,1}|\\(Date\\),{0,1}|\\(Time\\),{0,1}|\\(LocalDate\\),{0,1}|\\(LocalTime\\),{0,1}|\\(LocalDateTime\\),{0,1}|\\(Byte\\),{0,1}|\\(Short\\),{0,1}|\\(Integer\\),{0,1}|\\(Long\\),{0,1}|\\(Float\\),{0,1}|\\(Double\\),{0,1}|\\(BigDecimal\\),{0,1}|\\(Boolean\\),{0,1}|\\(Null\\),{0,1}");

        /** SQL text captured from the last "Preparing" message on the current thread. */
        public ThreadLocal<String> threadLocal = new ThreadLocal<>();

        private Logger logger;

        /**
         * @param logger the real logger every call is delegated to
         */
        public void setLogger(Logger logger) {
            this.logger = logger;
        }

        /**
         * Delegates the call to the real logger exactly once and returns its result. For loggers
         * under {@code "cn.com.xxx"}, and methods other than the {@code isXxxEnabled} checks, the
         * fourth argument is additionally inspected when it is a {@code String}.
         *
         * @throws Throwable whatever the delegated logger method throws (unwrapped)
         */
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            Object originResult = invokeDelegate(method, args);

            if (!logger.getName().startsWith(WRAPPED_LOGGER_PREFIX)
                    || skipMethodSet.contains(method.getName())) {
                return originResult;
            }
            // Explicit shape check instead of catching the exception: the previous fallback
            // re-invoked the delegate and therefore emitted ordinary log calls twice.
            if (args == null || args.length <= MESSAGE_ARG_INDEX
                    || !(args[MESSAGE_ARG_INDEX] instanceof String)) {
                return originResult;
            }

            String message = (String) args[MESSAGE_ARG_INDEX];
            if (message.startsWith(PREPARING)) {
                threadLocal.set(message.substring(PREPARING.length()));
            } else if (message.startsWith(PARAMETERS)) {
                printExecutableSql(message.substring(PARAMETERS.length()));
            }
            return originResult;
        }

        /** Calls the real logger and rethrows its own exception rather than the reflection wrapper. */
        private Object invokeDelegate(Method method, Object[] args) throws Throwable {
            try {
                return method.invoke(logger, args);
            } catch (InvocationTargetException e) {
                Throwable cause = e.getCause();
                throw cause != null ? cause : e;
            }
        }

        /**
         * Combines the SQL remembered for this thread with the given parameter text and prints
         * it. Nothing is printed when no SQL was remembered or it contains no {@code ?}. The
         * remembered SQL is always cleared afterwards.
         */
        private void printExecutableSql(String rawParams) {
            String sql = threadLocal.get();
            try {
                if (StringUtils.isNotBlank(sql) && sql.contains("?")) {
                    printSQL(SqlFormatter.format(fillPlaceholders(sql, toSqlLiterals(rawParams))));
                }
            } catch (RuntimeException e) {
                // Best effort only: rendering must never break the application's logging call.
                // Uses the real logger, not the proxy, so this cannot re-enter the interceptor.
                logger.warn("Failed to render SQL with bound parameters", e);
            } finally {
                threadLocal.remove();
            }
        }

        /**
         * Splits {@code value(Type), value(Type), null, ...} into one SQL literal per parameter.
         * A {@code (Type)} occurrence only ends a parameter when it is followed by {@code ", "}
         * or by the end of the text, so commas and type-like text inside a value do not shift the
         * parameter positions. A value that itself begins with {@code "null, "} is ambiguous and
         * is read as a null parameter followed by the rest.
         */
        private List<String> toSqlLiterals(String rawParams) {
            List<String> literals = new ArrayList<String>();
            Matcher suffix = pattern.matcher(rawParams);
            int cursor = 0;
            while (suffix.find()) {
                String token = suffix.group();
                boolean atEnd = suffix.end() == rawParams.length();
                boolean beforeSeparator = token.endsWith(",") && rawParams.startsWith(" ", suffix.end());
                if (!atEnd && !beforeSeparator) {
                    continue; // "(Type)" text inside a value, not a parameter boundary
                }
                String value = skipLeadingNulls(rawParams.substring(cursor, suffix.start()), literals);
                String type = token.substring(1, token.indexOf(')'));
                literals.add(QUOTED_TYPES.contains(type) ? quote(value) : value);
                // Step over the single separator space so it is not taken as part of the next value.
                cursor = beforeSeparator ? suffix.end() + 1 : suffix.end();
            }

            String tail = skipLeadingNulls(rawParams.substring(cursor), literals);
            if (NULL_TOKEN.equals(tail)) {
                literals.add(NULL_TOKEN);
            } else if (!tail.isEmpty()) {
                // Unrecognised type suffix: keep the text as it was logged.
                literals.add(tail);
            }
            return literals;
        }

        /** Emits one {@code null} literal per leading {@code "null, "} and returns the remainder. */
        private static String skipLeadingNulls(String segment, List<String> literals) {
            String rest = segment;
            while (rest.startsWith(NULL_THEN_SEPARATOR)) {
                literals.add(NULL_TOKEN);
                rest = rest.substring(NULL_THEN_SEPARATOR.length());
            }
            return rest;
        }

        /** Wraps the value in single quotes, doubling any embedded single quote. */
        private static String quote(String value) {
            return "'" + value.replace("'", "''") + "'";
        }

        /**
         * Replaces the n-th {@code ?} of the original SQL with the n-th literal in one pass, so
         * characters inside a substituted value ({@code ?}, {@code $}, {@code \}) are never
         * re-interpreted. Surplus placeholders are left as {@code ?}; surplus literals are
         * ignored. A {@code ?} inside a SQL string literal is not distinguished from a
         * placeholder.
         */
        private static String fillPlaceholders(String sql, List<String> literals) {
            StringBuilder filled = new StringBuilder(sql.length() + 16 * literals.size());
            int next = 0;
            for (int i = 0; i < sql.length(); i++) {
                char c = sql.charAt(i);
                if (c == '?' && next < literals.size()) {
                    filled.append(literals.get(next++));
                } else {
                    filled.append(c);
                }
            }
            return filled.toString();
        }

        /**
         * Prints the statement to standard output between two dashed separator lines, wrapped in
         * ANSI escape sequences and terminated with {@code ;}.
         *
         * @param sqlString the SQL to print; a leading {@code "==> Preparing: "} is removed
         */
        public final void printSQL(String sqlString) {
            if (sqlString.startsWith(PREPARING)) {
                sqlString = sqlString.substring(PREPARING.length());
            }
            System.out.println("\033[32;4m"
                    + "\r\n;;-- ------------------------------------------------------------------------------------------------------------------------------\r\n"
                    + "\r\n"
                    + sqlString
                    + ";"
                    + "\r\n"
                    + "\r\n;;-- ------------------------------------------------------------------------------------------------------------------------------\r\n"
                    + "\033[0m");
        }
    }
}