模仿MybatisPlus实现 lambda query

刚接触MybatisPlus那会,就觉得它提供的lambda调用方式挺酷的

比如常规写法new QueryWrapper.eq(“name”, “xiaoming”)
改成用lambda就是:
new LambdaQueryWrapper.eq(User::getName, “xiaoming”);

看起来差别不大,但是对于名字比较长的字段名,直接idea自动提示填充就行了,再也不用去copy了,也不用担心会写错啥的,修改字段名也能自动rename,总之就是很爽。

一直都觉得这是个顺理成章的功能,没有想过怎么去实现它,直到这几天想自己实现一个类似的功能,发现其实没这么简单。

正常思维是调用方都把类名和方法名传过去了,拿到解析成name不是很容易吗。但是实际上被调用的方法里只能获取到一个function而已,下面是我写的模拟代码:

public class TestLambdaQuery {
    public static void main(String[] args) {
        MyLambdaQueryWrapper<User> lambdaQueryWrapper = new MyLambdaQueryWrapper<>();

        String sql = lambdaQueryWrapper.eqSql(User::getName, "xiaoming");
        Assert.isTrue("name=xiaoming".equals(sql), "expect sql is 'name=xiaoming'");

        String sql2 = lambdaQueryWrapper.eqSql(User::getAge, 18);
        Assert.isTrue("age=18".equals(sql2), "expect sql is 'age=18'");
    }
}

public class MyLambdaQueryWrapper<T> {
    public String eqSql(Function<T, Object> function, Object value) {
        //需要在这里从function中解析得到name、age等
        return  "name="+value;
    }
}

这里的function就是个对象,
System.out.println(function.getClass());:
class com.example.demo.lambda.TestLambdaQuery$$Lambda$1/1732398722

根据反射可以推测出是TestLambdaQuery的一个内部类,实现了Function接口,并重写了apply方法。
具体可以去了解下JDK8中lambda的实现原理,推荐这篇文章:https://cloud.tencent.com/developer/article/1350313

在main方法第一行添加:
System.setProperty(“jdk.internal.lambda.dumpProxyClasses”, “J:\apps\demo\target\classes”);
即可把生成的class 保存到本地目录,再通过idea反编译打开,如下:

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by Fernflower decompiler)
//

package com.example.demo.lambda;

import java.lang.invoke.LambdaForm.Hidden;
import java.util.function.Function;

// $FF: synthetic class
final class TestLambdaQuery$$Lambda$1 implements Function {
    private TestLambdaQuery$$Lambda$1() {
    }

    @Hidden
    public Object apply(Object var1) {
        return ((User)var1).getName();
    }
}

那么现在问题就在于怎么从这里解析得到getName了。
我在思考这一步的时候就陷入困境了,通过常规的反射肯定拿不到想要的信息,通过字节码去解析,工作量又太大(主要是不会)。

没办法只好看下mp源码,看人家是怎么实现的。
关键代码入口在AbstractLambdaWrapper.columnToString()方法:
模仿MybatisPlus实现 lambda query_第1张图片
进一步定位到SerializedLambda.resolve()方法:
模仿MybatisPlus实现 lambda query_第2张图片
看起来很简单,其实就是把function对象先序列化,再反序列化成SerializedLambda对象。

关于SerializedLambda对象,mp的注释也写的很清楚了:
模仿MybatisPlus实现 lambda query_第3张图片

关于为什么要copy一个SerializedLambda对象,我前面也有点疑惑,后面自己尝试去写的时候才发现,是不得不这么做,直接序列化成java.lang.invoke.SerializedLambda对象是会报错的。

下面贴一下自己的完整代码。
其中SerializedLambda类是直接从mp中copy过来然后做了一些删减的,需要注意的是这个类名不能改,不然反序列化的时候会报错。
还有就是java8提供的Function接口不支持序列化,需要自己定义一个新的支持序列化的SFunction接口

import org.springframework.util.Assert;

public class TestLambdaQuery {

    public static void main(String[] args) {
        MyLambdaQueryWrapper<User> lambdaQueryWrapper = new MyLambdaQueryWrapper<>();

        String sql = lambdaQueryWrapper.eqSql(User::getName, "xiaoming");
        Assert.isTrue("name=xiaoming".equals(sql), "expect sql is 'name=xiaoming'");

        String sql2 = lambdaQueryWrapper.eqSql(User::getAge, 18);
        Assert.isTrue("age=18".equals(sql2), "expect sql is 'age=18'");
    }
}

import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
public class User {
    private String name = "xiaoming";
    private Integer age = 18;
}
import java.io.Serializable;
import java.util.function.Function;

/**
 * 支持序列化的 Function
 */
@FunctionalInterface
public interface SFunction<T, R> extends Function<T, R>, Serializable {
}
import java.io.*;

public class MyLambdaQueryWrapper<T> {

    public String eqSql(SFunction<T, Object> function, Object value) {
        SerializedLambda serializedLambda = deserialize(serialize(function));

        String key = methodToProperty(serializedLambda.getImplMethodName());
        return  key + "=" + value;
    }

    /**
     * 方法名转字段名,直接copy mp的代码
     * 其实就是按java bean的规范,先把get、set、is前缀去掉,然后第二个字符如果不是大写,就把第一个转小写
     * @param name
     * @return
     */
    public static String methodToProperty(String name) {
        if (name.startsWith("is")) {
            name = name.substring(2);
        } else if (name.startsWith("get") || name.startsWith("set")) {
            name = name.substring(3);
        } else {
            throw new RuntimeException("Error parsing property name '" + name + "'.  Didn't start with 'is', 'get' or 'set'.");
        }

        if (name.length() == 1 || (name.length() > 1 && !Character.isUpperCase(name.charAt(1)))) {
            name = name.substring(0, 1).toLowerCase(Locale.ENGLISH) + name.substring(1);
        }

        return name;
    }

    private SerializedLambda deserialize(byte[] bytes) {
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes)){
            @Override
            protected Class<?> resolveClass(ObjectStreamClass objectStreamClass) throws IOException, ClassNotFoundException {
                return objectStreamClass.getName().equals(java.lang.invoke.SerializedLambda.class.getName()) ? SerializedLambda.class : super.resolveClass(objectStreamClass);
            }
        }){
            return (SerializedLambda) ois.readObject();
        } catch (Exception e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }
    }

    private byte[] serialize(SFunction<T, Object> function) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream(1024);
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)){
            oos.writeObject(function);
            oos.flush();
        } catch (IOException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }
        return baos.toByteArray();
    }
}
import com.baomidou.mybatisplus.core.toolkit.ClassUtils;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.SerializationUtils;

import java.io.*;

/**
 * 这个类是从 {@link java.lang.invoke.SerializedLambda} 里面 copy 过来的,
 * 字段信息完全一样
 * 

负责将一个支持序列的 Function 序列化为 SerializedLambda

* * @author HCL * @since 2018/05/10 */
@SuppressWarnings("unused") public class SerializedLambda implements Serializable { private static final long serialVersionUID = 8025925345765570181L; private Class<?> capturingClass; private String functionalInterfaceClass; private String functionalInterfaceMethodName; private String functionalInterfaceMethodSignature; private String implClass; private String implMethodName; private String implMethodSignature; private int implMethodKind; private String instantiatedMethodType; private Object[] capturedArgs; /** * 通过反序列化转换 lambda 表达式,该方法只能序列化 lambda 表达式,不能序列化接口实现或者正常非 lambda 写法的对象 * * @param lambda lambda对象 * @return 返回解析后的 SerializedLambda */ public static SerializedLambda resolve(SFunction<?, ?> lambda) { if (!lambda.getClass().isSynthetic()) { throw ExceptionUtils.mpe("该方法仅能传入 lambda 表达式产生的合成类"); } try (ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(SerializationUtils.serialize(lambda))) { @Override protected Class<?> resolveClass(ObjectStreamClass objectStreamClass) throws IOException, ClassNotFoundException { Class<?> clazz; try { clazz = ClassUtils.toClassConfident(objectStreamClass.getName()); } catch (Exception ex) { clazz = super.resolveClass(objectStreamClass); } return clazz == java.lang.invoke.SerializedLambda.class ? SerializedLambda.class : clazz; } }) { return (SerializedLambda) objIn.readObject(); } catch (ClassNotFoundException | IOException e) { throw ExceptionUtils.mpe("This is impossible to happen", e); } } /** * 获取接口 class * * @return 返回 class 名称 */ public String getFunctionalInterfaceClassName() { return normalizedName(functionalInterfaceClass); } /** * 获取实现的 class * * @return 实现类 */ public Class<?> getImplClass() { return ClassUtils.toClassConfident(getImplClassName()); } /** * 获取 class 的名称 * * @return 类名 */ public String getImplClassName() { return normalizedName(implClass); } /** * 获取实现者的方法名称 * * @return 方法名称 */ public String getImplMethodName() { return implMethodName; } /** * 正常化类名称,将类名称中的 / 替换为 . * * @param name 名称 * @return 正常的类名 */ private String normalizedName(String name) { return name.replace('/', '.'); } /** * @return 获取实例化方法的类型 */ public Class<?> getInstantiatedType() { String instantiatedTypeName = normalizedName(instantiatedMethodType.substring(2, instantiatedMethodType.indexOf(';'))); return ClassUtils.toClassConfident(instantiatedTypeName); } /** * @return 字符串形式 */ @Override public String toString() { String interfaceName = getFunctionalInterfaceClassName(); String implName = getImplClassName(); return String.format("%s -> %s::%s", interfaceName.substring(interfaceName.lastIndexOf('.') + 1), implName.substring(implName.lastIndexOf('.') + 1), implMethodName); } }

你可能感兴趣的:(技术研究,技术踩坑)