刚接触MybatisPlus那会,就觉得它提供的lambda调用方式挺酷的
比如常规写法new QueryWrapper.eq(“name”, “xiaoming”)
改成用lambda就是:
new LambdaQueryWrapper.eq(User::getName, “xiaoming”);
看起来差别不大,但是对于名字比较长的字段名,直接idea自动提示填充就行了,再也不用去copy了,也不用担心会写错啥的,修改字段名也能自动rename,总之就是很爽。
一直都觉得这是个顺理成章的功能,没有想过怎么去实现它,直到这几天想自己实现一个类似的功能,发现其实没这么简单。
正常思维是调用方都把类名和方法名传过去了,拿到解析成name不是很容易吗。但是实际上被调用的方法里只能获取到一个function而已,下面是我写的模拟代码:
public class TestLambdaQuery {
public static void main(String[] args) {
MyLambdaQueryWrapper<User> lambdaQueryWrapper = new MyLambdaQueryWrapper<>();
String sql = lambdaQueryWrapper.eqSql(User::getName, "xiaoming");
Assert.isTrue("name=xiaoming".equals(sql), "expect sql is 'name=xiaoming'");
String sql2 = lambdaQueryWrapper.eqSql(User::getAge, 18);
Assert.isTrue("age=18".equals(sql2), "expect sql is 'age=18'");
}
}
public class MyLambdaQueryWrapper<T> {
public String eqSql(Function<T, Object> function, Object value) {
//需要在这里从function中解析得到name、age等
return "name="+value;
}
}
这里的function就是个对象,
System.out.println(function.getClass());:
class com.example.demo.lambda.TestLambdaQuery$$Lambda$1/1732398722
根据反射可以推测出是TestLambdaQuery的一个内部类,实现了Function接口,并重写了apply方法。
具体可以去了解下JDK8中lambda的实现原理,推荐这篇文章:https://cloud.tencent.com/developer/article/1350313
在main方法第一行添加:
System.setProperty(“jdk.internal.lambda.dumpProxyClasses”, “J:\apps\demo\target\classes”);
即可把生成的class 保存到本地目录,再通过idea反编译打开,如下:
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by Fernflower decompiler)
//
package com.example.demo.lambda;
import java.lang.invoke.LambdaForm.Hidden;
import java.util.function.Function;
// $FF: synthetic class
final class TestLambdaQuery$$Lambda$1 implements Function {
private TestLambdaQuery$$Lambda$1() {
}
@Hidden
public Object apply(Object var1) {
return ((User)var1).getName();
}
}
那么现在问题就在于怎么从这里解析得到getName了。
我在思考这一步的时候就陷入困境了,通过常规的反射肯定拿不到想要的信息,通过字节码去解析,工作量又太大(主要是不会)。
没办法只好看下mp源码,看人家是怎么实现的。
关键代码入口在AbstractLambdaWrapper.columnToString()方法:
进一步定位到SerializedLambda.resolve()方法:
看起来很简单,其实就是把function对象先序列化,再反序列化成SerializedLambda对象。
关于SerializedLambda对象,mp的注释也写的很清楚了:
关于为什么要copy一个SerializedLambda对象,我前面也有点疑惑,后面自己尝试去写的时候才发现,是不得不这么做,直接序列化成java.lang.invoke.SerializedLambda对象是会报错的。
下面贴一下自己的完整代码。
其中SerializedLambda类是直接从mp中copy过来然后做了一些删减的,需要注意的是这个类名不能改,不然反序列化的时候会报错。
还有就是java8提供的Function接口不支持序列化,需要自己定义一个新的支持序列化的SFunction接口
import org.springframework.util.Assert;
public class TestLambdaQuery {
public static void main(String[] args) {
MyLambdaQueryWrapper<User> lambdaQueryWrapper = new MyLambdaQueryWrapper<>();
String sql = lambdaQueryWrapper.eqSql(User::getName, "xiaoming");
Assert.isTrue("name=xiaoming".equals(sql), "expect sql is 'name=xiaoming'");
String sql2 = lambdaQueryWrapper.eqSql(User::getAge, 18);
Assert.isTrue("age=18".equals(sql2), "expect sql is 'age=18'");
}
}
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
public class User {
private String name = "xiaoming";
private Integer age = 18;
}
import java.io.Serializable;
import java.util.function.Function;
/**
* 支持序列化的 Function
*/
@FunctionalInterface
public interface SFunction<T, R> extends Function<T, R>, Serializable {
}
import java.io.*;
public class MyLambdaQueryWrapper<T> {
public String eqSql(SFunction<T, Object> function, Object value) {
SerializedLambda serializedLambda = deserialize(serialize(function));
String key = methodToProperty(serializedLambda.getImplMethodName());
return key + "=" + value;
}
/**
* 方法名转字段名,直接copy mp的代码
* 其实就是按java bean的规范,先把get、set、is前缀去掉,然后第二个字符如果不是大写,就把第一个转小写
* @param name
* @return
*/
public static String methodToProperty(String name) {
if (name.startsWith("is")) {
name = name.substring(2);
} else if (name.startsWith("get") || name.startsWith("set")) {
name = name.substring(3);
} else {
throw new RuntimeException("Error parsing property name '" + name + "'. Didn't start with 'is', 'get' or 'set'.");
}
if (name.length() == 1 || (name.length() > 1 && !Character.isUpperCase(name.charAt(1)))) {
name = name.substring(0, 1).toLowerCase(Locale.ENGLISH) + name.substring(1);
}
return name;
}
private SerializedLambda deserialize(byte[] bytes) {
try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes)){
@Override
protected Class<?> resolveClass(ObjectStreamClass objectStreamClass) throws IOException, ClassNotFoundException {
return objectStreamClass.getName().equals(java.lang.invoke.SerializedLambda.class.getName()) ? SerializedLambda.class : super.resolveClass(objectStreamClass);
}
}){
return (SerializedLambda) ois.readObject();
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}
private byte[] serialize(SFunction<T, Object> function) {
ByteArrayOutputStream baos = new ByteArrayOutputStream(1024);
try (ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(function);
oos.flush();
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
return baos.toByteArray();
}
}
import com.baomidou.mybatisplus.core.toolkit.ClassUtils;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.SerializationUtils;
import java.io.*;
/**
* 这个类是从 {@link java.lang.invoke.SerializedLambda} 里面 copy 过来的,
* 字段信息完全一样
* 负责将一个支持序列的 Function 序列化为 SerializedLambda
*
* @author HCL
* @since 2018/05/10
*/
@SuppressWarnings("unused")
public class SerializedLambda implements Serializable {
private static final long serialVersionUID = 8025925345765570181L;
private Class<?> capturingClass;
private String functionalInterfaceClass;
private String functionalInterfaceMethodName;
private String functionalInterfaceMethodSignature;
private String implClass;
private String implMethodName;
private String implMethodSignature;
private int implMethodKind;
private String instantiatedMethodType;
private Object[] capturedArgs;
/**
* 通过反序列化转换 lambda 表达式,该方法只能序列化 lambda 表达式,不能序列化接口实现或者正常非 lambda 写法的对象
*
* @param lambda lambda对象
* @return 返回解析后的 SerializedLambda
*/
public static SerializedLambda resolve(SFunction<?, ?> lambda) {
if (!lambda.getClass().isSynthetic()) {
throw ExceptionUtils.mpe("该方法仅能传入 lambda 表达式产生的合成类");
}
try (ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(SerializationUtils.serialize(lambda))) {
@Override
protected Class<?> resolveClass(ObjectStreamClass objectStreamClass) throws IOException, ClassNotFoundException {
Class<?> clazz;
try {
clazz = ClassUtils.toClassConfident(objectStreamClass.getName());
} catch (Exception ex) {
clazz = super.resolveClass(objectStreamClass);
}
return clazz == java.lang.invoke.SerializedLambda.class ? SerializedLambda.class : clazz;
}
}) {
return (SerializedLambda) objIn.readObject();
} catch (ClassNotFoundException | IOException e) {
throw ExceptionUtils.mpe("This is impossible to happen", e);
}
}
/**
* 获取接口 class
*
* @return 返回 class 名称
*/
public String getFunctionalInterfaceClassName() {
return normalizedName(functionalInterfaceClass);
}
/**
* 获取实现的 class
*
* @return 实现类
*/
public Class<?> getImplClass() {
return ClassUtils.toClassConfident(getImplClassName());
}
/**
* 获取 class 的名称
*
* @return 类名
*/
public String getImplClassName() {
return normalizedName(implClass);
}
/**
* 获取实现者的方法名称
*
* @return 方法名称
*/
public String getImplMethodName() {
return implMethodName;
}
/**
* 正常化类名称,将类名称中的 / 替换为 .
*
* @param name 名称
* @return 正常的类名
*/
private String normalizedName(String name) {
return name.replace('/', '.');
}
/**
* @return 获取实例化方法的类型
*/
public Class<?> getInstantiatedType() {
String instantiatedTypeName = normalizedName(instantiatedMethodType.substring(2, instantiatedMethodType.indexOf(';')));
return ClassUtils.toClassConfident(instantiatedTypeName);
}
/**
* @return 字符串形式
*/
@Override
public String toString() {
String interfaceName = getFunctionalInterfaceClassName();
String implName = getImplClassName();
return String.format("%s -> %s::%s",
interfaceName.substring(interfaceName.lastIndexOf('.') + 1),
implName.substring(implName.lastIndexOf('.') + 1),
implMethodName);
}
}