利用mybatis插件简化批量插入/更新

批量插入/更新的痛点

sql中的批量插入语法如下:

insert into user (id, name) values (?, ?), (?, ?), (?, ?)

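对于mybatis,在XML映射文件中需要借助<foreach>手工拼接values列表:

<insert id="batchInsert">
  insert into user (id, name) values
  <foreach collection="list" item="user" separator=",">
    (#{user.id}, #{user.name})
  </foreach>
</insert>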

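如果使用注解,写法则更加繁琐,需要把XML片段包在<script>中:

  @Insert({
      "<script>",
      "insert into user (id, name) values",
      "<foreach collection='list' item='user' separator=','>",
      "(#{user.id}, #{user.name})",
      "</foreach>",
      "</script>"
  })
  int batchInsert(List<User> users);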

优化目标

希望通过mybatis插件达到按如下方式执行batch操作:

  • 注解使用方式
  @Insert("insert into user (id, name) values (#{id}, #{name})")
  int batchInsert(List<User> users);

  @Update("update user set name = #{name} where id = #{id}")
  int batchUpdate(List<User> users);
  • XML使用方式

<insert id="batchInsert">
  insert into user (id, name) values (#{id}, #{name})
</insert>

<update id="batchUpdate">
  update user set name = #{name} where id = #{id}
</update>

实现方式

原理:通过mybatis插件识别出statement id(即mapper方法名)以batch开头的insert/update语句,并通过mybatis的BatchExecutor把参数集合中的每个元素作为一条语句批量执行(详情见代码库:https://github.com/gaohanghbut/stupidmybatis):


import cn.yxffcode.stupidmybatis.commons.BatchUtils;
import cn.yxffcode.stupidmybatis.commons.ExecutorUtils;
import cn.yxffcode.stupidmybatis.commons.Reflections;
import org.apache.ibatis.executor.BatchExecutor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.Configuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.SQLException;
import java.util.Properties;

/**
 * 方便使用的批量更新插件,只需要sql statement id以batch开头,参数为Iterable或者数组即可.
 * 

* 限制:最好是作为第一个拦截器使用,因为在它之前的拦截器不会被调用 * * @author gaohang on 16/7/29. */ @Intercepts({ @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}), }) public class BatchExecutorInterceptor implements Interceptor { private static final Logger LOGGER = LoggerFactory.getLogger(BatchExecutorInterceptor.class); @Override public Object intercept(final Invocation invocation) throws Throwable { //check argument if (invocation.getArgs()[1] == null) { return invocation.proceed(); } final MappedStatement ms = (MappedStatement) invocation.getArgs()[0]; //if it should use batch if (!BatchUtils.shouldDoBatch(ms.getId())) { return invocation.proceed(); } //create batch executor final Executor targetExecutor = ExecutorUtils.getTargetExecutor((Executor) invocation.getTarget()); if (targetExecutor instanceof BatchExecutor) { return invocation.proceed(); } final Configuration configuration = (Configuration) Reflections.getField("configuration", targetExecutor); final BatchExecutor batchExecutor = new BatchExecutorAdaptor(configuration, targetExecutor.getTransaction()); try { return batchExecutor.update(ms, invocation.getArgs()[1]); } catch (SQLException e) { batchExecutor.flushStatements(true); throw e; } } @Override public Object plugin(final Object target) { if (!(target instanceof Executor)) { return target; } if (target instanceof BatchExecutor) { return target; } return Plugin.wrap(target, this); } @Override public void setProperties(final Properties properties) { } }

其中使用到了BatchExecutorAdaptor,它是BatchExecutor的子类:


import com.google.common.base.Throwables;
import org.apache.ibatis.executor.BatchExecutor;
import org.apache.ibatis.executor.BatchResult;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.transaction.Transaction;

import java.lang.reflect.Array;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * {@link BatchExecutor} that turns a single {@code update} call into one JDBC batch. The
 * parameter is expanded into one statement per element, the batch is flushed before returning,
 * and the summed update counts are returned.
 *
 * <p>{@code BatchExecutorInterceptor} creates one instance per intercepted call, sharing the
 * calling executor's {@link Transaction}.
 *
 * @author gaohang on 7/30/16.
 */
final class BatchExecutorAdaptor extends BatchExecutor {
  public BatchExecutorAdaptor(Configuration configuration, Transaction transaction) {
    super(configuration, transaction);
  }

  /**
   * Runs {@code ms} once per batch element and flushes the batch before returning, so every
   * queued statement is executed.
   *
   * <p>How the batch elements are chosen:
   * <ul>
   *   <li>Map with exactly one entry: the elements of that entry's value.</li>
   *   <li>Other Map: the elements of its {@code "list"} entry if present, otherwise the statement
   *       runs once with the whole map.</li>
   *   <li>{@link Iterable} or array (object or primitive): each element.</li>
   *   <li>Anything else, including {@code null}: the statement runs once with that value.</li>
   * </ul>
   *
   * @param ms the mapped statement to execute
   * @param parameter the (possibly MyBatis-wrapped) statement parameter
   * @return the sum of the positive update counts reported by the driver; 0 if nothing ran
   * @throws SQLException if queuing or executing the batch fails; queued statements are
   *     discarded first and the original exception is rethrown unwrapped
   */
  @Override
  public int update(MappedStatement ms, Object parameter) throws SQLException {
    final Iterable<?> paramIterable = toIterable(extractBatchParams(parameter));
    try {
      for (Object element : paramIterable) {
        // Queued into the batch; executed by the flush below.
        super.update(ms, element);
      }
      final List<BatchResult> batchResults = doFlushStatements(false);
      if (batchResults == null || batchResults.isEmpty()) {
        return 0;
      }
      return resolveUpdateResult(batchResults);
    } catch (SQLException e) {
      doFlushStatements(true);
      throw e;
    } catch (RuntimeException e) {
      doFlushStatements(true);
      throw e;
    }
  }

  /** Picks the object whose elements form the batch out of the (possibly wrapped) parameter. */
  private static Object extractBatchParams(final Object parameter) {
    if (parameter == null) {
      // Run the statement once with a null parameter rather than skipping it.
      return Collections.<Object>singletonList(null);
    }
    if (!(parameter instanceof Map)) {
      return parameter;
    }
    final Map<?, ?> paramMap = (Map<?, ?>) parameter;
    if (paramMap.size() == 1) {
      return paramMap.values().iterator().next();
    }
    if (paramMap.containsKey("list")) {
      return paramMap.get("list");
    }
    // Unrecognized map: execute once with the whole map (still flushed by update()).
    return Collections.singletonList(parameter);
  }

  /** Views {@code params} as an Iterable; primitive arrays are boxed element by element. */
  private static Iterable<?> toIterable(final Object params) {
    if (params == null) {
      return Collections.emptyList();
    }
    if (params instanceof Iterable) {
      return (Iterable<?>) params;
    }
    if (params.getClass().isArray()) {
      // java.lang.reflect.Array handles int[], long[], ... as well as Object[].
      final int length = Array.getLength(params);
      final List<Object> elements = new ArrayList<Object>(length);
      for (int i = 0; i < length; i++) {
        elements.add(Array.get(params, i));
      }
      return elements;
    }
    return Collections.singletonList(params);
  }

  /** Sums the positive per-statement update counts of all batch results. */
  private static int resolveUpdateResult(final List<BatchResult> batchResults) {
    int result = 0;
    for (BatchResult batchResult : batchResults) {
      final int[] updateCounts = batchResult.getUpdateCounts();
      if (updateCounts == null) {
        continue;
      }
      for (int updateCount : updateCounts) {
        // Negative values are JDBC status codes (e.g. Statement.SUCCESS_NO_INFO), not row counts.
        if (updateCount > 0) {
          result += updateCount;
        }
      }
    }
    return result;
  }

}

使用到的utils类:

/**
 * Decides which mapped statements should be executed as JDBC batches.
 *
 * @author gaohang on 16/7/31.
 */
public final class BatchUtils {
  private BatchUtils() {
  }

  /**
   * Returns whether the simple name of a mapped statement id starts with {@code "batch"}. The
   * simple name is the part after the last {@code '.'}, or the whole id when it has no dot.
   *
   * @param statementId mapped statement id, e.g. {@code com.foo.UserMapper.batchInsert};
   *     may be {@code null}
   * @return {@code true} for ids like {@code ...UserMapper.batchInsert}; {@code false} for
   *     {@code null} or any other simple name
   */
  public static boolean shouldDoBatch(final String statementId) {
    if (statementId == null) {
      return false;
    }
    final int simpleNameStart = statementId.lastIndexOf('.') + 1;
    return statementId.startsWith("batch", simpleNameStart);
  }
}

/**
 * Helpers for locating the executor that actually runs statements.
 *
 * @author gaohang on 16/8/1.
 */
public final class ExecutorUtils {
  private ExecutorUtils() {
  }

  /**
   * Strips JDK proxies and {@link CachingExecutor} wrappers, in any order and any number of
   * layers, until neither remains.
   *
   * @param executor the (possibly wrapped) executor
   * @return the innermost executor that is neither a {@link Proxy} nor a {@link CachingExecutor}
   */
  public static Executor getTargetExecutor(final Executor executor) {
    Executor current = executor;
    while (true) {
      if (current instanceof Proxy) {
        // Assumes the invocation handler stores the wrapped executor in a field named "target"
        // (as MyBatis' Plugin does) — other handler types are not supported.
        current = (Executor) Reflections.getField("target", Proxy.getInvocationHandler(current));
      } else if (current instanceof CachingExecutor) {
        current = (Executor) Reflections.getField("delegate", current);
      } else {
        return current;
      }
    }
  }
}

import com.google.common.base.Throwables;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;

/**
 * Reflection helpers for reading and writing (possibly private) fields and for invoking
 * annotation members.
 *
 * @author gaohang on 15/12/4.
 */
public final class Reflections {
  private Reflections() {
  }

  private static Field findField(Class<?> clazz, String name) {
    return findField(clazz, name, null);
  }

  /**
   * Searches {@code clazz} and then its superclasses (stopping before {@link Object}) for a
   * declared field.
   *
   * @param clazz class to start from; {@code null} yields {@code null}
   * @param name field name to match, or {@code null} to match any name
   * @param type exact field type to match, or {@code null} to match any type
   * @return the first matching field (subclass fields win), or {@code null} if none matches
   */
  public static Field findField(Class<?> clazz, String name, Class<?> type) {
    for (Class<?> searchType = clazz;
        searchType != null && !Object.class.equals(searchType);
        searchType = searchType.getSuperclass()) {
      for (Field field : searchType.getDeclaredFields()) {
        final boolean nameMatches = name == null || name.equals(field.getName());
        final boolean typeMatches = type == null || type.equals(field.getType());
        if (nameMatches && typeMatches) {
          return field;
        }
      }
    }
    return null;
  }

  /**
   * Reads a field (declared on the target's class or any superclass) regardless of visibility.
   *
   * @throws IllegalStateException if no such field exists or it cannot be read
   */
  public static Object getField(String fieldName, Object target) {
    final Field field = requireAccessibleField(target, fieldName);
    try {
      return field.get(target);
    } catch (IllegalAccessException ex) {
      throw new IllegalStateException("Unexpected reflection exception - " + ex.getClass()
          .getName() + ": " + ex.getMessage(), ex);
    }
  }

  /**
   * Writes a field (declared on the target's class or any superclass) regardless of visibility.
   *
   * @throws IllegalStateException if no such field exists or the value cannot be assigned
   */
  public static void setField(Object target, String fieldName, Object value) {
    final Field field = requireAccessibleField(target, fieldName);
    try {
      field.set(target, value);
    } catch (Exception ex) {
      throw new IllegalStateException("Unexpected reflection exception - " + ex.getClass()
          .getName() + ": " + ex.getMessage(), ex);
    }
  }

  /**
   * Invokes the no-argument member {@code methodName} of {@code annotation}.
   *
   * @param <T> expected member type; the cast is unchecked
   * @throws RuntimeException rethrown as-is; checked reflection failures are wrapped in
   *     {@link IllegalStateException}
   */
  @SuppressWarnings("unchecked")
  public static <T> T call(Annotation annotation, String methodName) {
    try {
      final Method method = annotation.annotationType().getMethod(methodName);
      return (T) method.invoke(annotation);
    } catch (RuntimeException e) {
      throw e;
    } catch (Exception e) {
      throw new IllegalStateException("Unexpected reflection exception - " + e.getClass()
          .getName() + ": " + e.getMessage(), e);
    }
  }

  /** Finds {@code fieldName} on the target's class hierarchy and makes it accessible. */
  private static Field requireAccessibleField(Object target, String fieldName) {
    final Field field = findField(target.getClass(), fieldName);
    if (field == null) {
      throw new IllegalStateException("No field '" + fieldName + "' found in "
          + target.getClass().getName() + " or its superclasses");
    }
    field.setAccessible(true);
    return field;
  }
}

你可能感兴趣的:(利用mybatis插件简化批量插入/更新)