Java: Implementing Dynamic Proxies for Concrete Classes with ASM

1. Analyzing dynamic proxies for concrete classes

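This article is not a tutorial on ASM itself; if you are interested, see the official ASM documentation. If the English is hard going, a Chinese translation of the ASM documentation is also available for download.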

For our purposes, all we need to do is add the ASM dependencies:

implementation 'org.ow2.asm:asm:7.0'
implementation 'org.ow2.asm:asm-commons:7.0'
implementation 'org.ow2.asm:asm-util:7.0'

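Before implementing a dynamic proxy for a concrete class, we first need to look at how a concrete class differs from an interface:

  1. Interface methods are implicitly public, and a class that implements an interface inherits all of the interface's methods; the methods of a concrete class may be private, protected, public, or default (package-private).
  2. A class that implements an interface can simply use the default no-argument constructor; a class that extends a concrete class may have several superclass constructors to mirror, and one of them must be chosen to instantiate the proxy object.
  3. Interface methods are never final; the methods of a concrete class may be.
  4. Interface methods that a proxy has to implement are never static (the static interface methods allowed since Java 8 are not inherited by implementing classes); the methods of a concrete class may be static.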
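
To make these differences concrete, here is a minimal sketch that uses only JDK reflection (not ASM) to list which declared methods a subclass-based proxy could override at all. The names `OverridableMethods` and `Sample` are hypothetical and exist only for this illustration.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class OverridableMethods {

    // Hypothetical target class, used only for illustration.
    static class Sample {
        public void open() {}
        protected void guarded() {}
        void packageLocal() {}
        private void hidden() {}
        public final void sealed() {}
        public static void utility() {}
    }

    /** Returns the sorted names of declared methods that a subclass could override. */
    static List<String> overridable(Class<?> type) {
        List<String> names = new ArrayList<>();
        for (Method m : type.getDeclaredMethods()) {
            int mod = m.getModifiers();
            // private, final and static methods cannot be overridden by a subclass
            if (m.isSynthetic() || Modifier.isPrivate(mod)
                    || Modifier.isFinal(mod) || Modifier.isStatic(mod)) {
                continue;
            }
            names.add(m.getName());
        }
        Collections.sort(names);
        return names;
    }

    public static void main(String[] args) {
        System.out.println(overridable(Sample.class)); // [guarded, open, packageLocal]
    }
}
```

Only the public, protected, and package-private instance methods remain, which is the set the proxy generator has to deal with.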

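Next, following the design of the JDK's dynamic proxy, here are the steps needed to implement a dynamic proxy for a concrete class:

  1. Define an InvocationHandler interface, used as the callback when a method is invoked on the proxy object
  2. Check the cache, keyed by classloader and class; if a proxy class is cached, use it directly. Otherwise the class would be generated again, and defining the same class twice in one classloader fails
  3. Check whether the target class is final; continue only if it is not
  4. Use ASM to create a new class in the same package as the target class, named with the $Proxy_ prefix. Set the proxy class's modifiers and superclass
  5. Add an InvocationHandler member field, to be used for the callback on later method calls
  6. Add a setter method for the InvocationHandler field
  7. Add constructors that mirror the constructors of the target class; each one simply calls super()
  8. Add a method that calls the invoke method of the InvocationHandler
  9. Add the proxied methods: select the public, protected, and default (package-private) methods, and have each one simply call the method created in step 8
  10. Add a static initializer block that initializes the newly created static Method fields
  11. Define the class through the classloader and put it in the cache
  12. Instantiate an object from the class and call the setter to pass in the InvocationHandler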
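
As a rough picture of what steps 4 through 10 produce, the following sketch writes out by hand a class with the same shape as the generated proxy. It is an illustration under assumptions, not the generator's actual output: `Greeter` is a hypothetical target class, and the static field is looked up by name here, whereas the generator described in this article indexes into the array returned by `getMethods()`.

```java
import java.lang.reflect.Method;

public class GeneratedProxyShape {

    // Same shape as the callback interface used in this article.
    interface InvocationHandler {
        Object invoke(Object proxy, Method method, Object[] args) throws Throwable;
    }

    // Hypothetical class to be proxied.
    static class Greeter {
        private final String prefix;
        public Greeter(String prefix) { this.prefix = prefix; }
        public String greet(String name) { return prefix + name; }
    }

    // Hand-written equivalent of the class that steps 4-10 would generate.
    static final class Proxy_Greeter extends Greeter {
        private static Method method0;                 // step 10: set in the static block
        private InvocationHandler invocationHandler;   // step 5

        static {
            try {
                method0 = Greeter.class.getMethod("greet", String.class);
            } catch (NoSuchMethodException e) {
                e.printStackTrace();
            }
        }

        public Proxy_Greeter(String prefix) {          // step 7: just calls super(...)
            super(prefix);
        }

        public void setInvocationHandler(InvocationHandler handler) { // step 6
            this.invocationHandler = handler;
        }

        // step 8: single place that forwards to the handler
        private Object invokeInvocationHandler(Object proxy, Method method, Object... args) {
            try {
                return invocationHandler.invoke(proxy, method, args);
            } catch (Throwable t) {
                t.printStackTrace();
                return null;
            }
        }

        @Override
        public String greet(String name) {             // step 9
            return (String) invokeInvocationHandler(this, method0, new Object[]{name});
        }
    }

    static String demo() {
        Proxy_Greeter proxy = new Proxy_Greeter("Hello, ");
        proxy.setInvocationHandler((p, method, args) -> "[" + method.getName() + "] " + args[0]);
        return proxy.greet("ASM");
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints "[greet] ASM"
    }
}
```

Because every overridden method funnels through `invokeInvocationHandler`, the bytecode that ASM has to emit per method stays small: load the arguments, box them into an array, call one helper, and unbox the result.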

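A few caveats and tips for working with ASM, stated up front:

  1. Writing code with ASM is tedious, so wrap as many steps as possible in ordinary Java methods; ASM then only has to create that Java method and emit calls to it. Creating a Java method with ASM is easy, because ASMifier can print the ASM code you need directly.
  2. If you add or remove an instance field, you must add or remove the corresponding field assignment in the constructor, that is, the "<init>" method
  3. If you add or remove a static field, you must add or remove the corresponding field assignment in the static initializer, that is, the "<clinit>" method
  4. In the local variable table, the primitive types long and double take two slots each; every other type takes exactly one slot
  5. All boxing and unboxing of primitive types must be done by hand
  6. When pushing a constant, prefer BIPUSH-style instructions over the xCONST_n shortcuts, e.g. BIPUSH 3 instead of ICONST_3. ICONST_n exists only for -1 through 5, whereas BIPUSH accepts any value from -128 to 127, so the same generator code handles more values.
  7. ASM frequently needs the internal name of a class, which is simply the fully qualified name with "." replaced by "/"
  8. When you need a type descriptor or method descriptor, ASM's Type class is very handy and largely avoids assembling descriptors by hand
  9. To read all the arguments of a method, it is best to define a Java method whose parameter is Object...; the arguments then only need to be passed through. Remember to box primitive types here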
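
Tips 4 and 7 are easy to get wrong, so here is a small JDK-only sketch of both rules. `LocalSlots`, `internalName`, and `parameterSlots` are hypothetical helper names for this illustration; in real ASM code the Type class offers equivalent functionality.

```java
import java.util.Arrays;

public class LocalSlots {

    /** Converts a fully qualified class name to the internal form used in class files. */
    static String internalName(Class<?> type) {
        return type.getName().replace('.', '/');
    }

    /**
     * Returns the local-variable index of each parameter of an instance method.
     * Slot 0 holds "this"; long and double take two slots, everything else one.
     */
    static int[] parameterSlots(Class<?>... parameterTypes) {
        int[] slots = new int[parameterTypes.length];
        int next = 1;
        for (int i = 0; i < parameterTypes.length; i++) {
            slots[i] = next;
            Class<?> t = parameterTypes[i];
            next += (t == long.class || t == double.class) ? 2 : 1;
        }
        return slots;
    }

    public static void main(String[] args) {
        System.out.println(internalName(String.class)); // java/lang/String
        System.out.println(Arrays.toString(
                parameterSlots(int.class, long.class, String.class, double.class, boolean.class)));
        // [1, 2, 4, 5, 7]
    }
}
```

Note that the slot of a parameter is not simply its position plus one: anything that follows a long or a double is shifted by an extra slot.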

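2. Implementing the dynamic proxy for concrete classes

First, mirroring the JDK, implement a callback interface for the proxied methods: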

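import java.lang.reflect.Method;

public interface InvocationHandler {

    /**
     * @param proxy  the dynamically generated proxy object
     * @param method the method being invoked
     * @param args   the arguments of the invocation
     * @return the return value of the method
     * @throws Throwable any exception raised while handling the call
     */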
    Object invoke(Object proxy, Method method, Object[] args) throws Throwable;

}

Next comes a ClassVisitor that reads some data from the target class:

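import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

import java.util.ArrayList;
import java.util.List;

public class TargetClassVisitor extends ClassVisitor {

    private boolean isFinal;
    private List<MethodBean> methods = new ArrayList<>();
    private List<MethodBean> declaredMethods = new ArrayList<>();
    private List<MethodBean> constructors = new ArrayList<>();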

    public TargetClassVisitor() {
        super(Proxy.ASM_VERSION);
    }

    @Override
    public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
        super.visit(version, access, name, signature, superName, interfaces);
        if ((access & Opcodes.ACC_FINAL) == Opcodes.ACC_FINAL){
            isFinal = true;
        }
        if (superName != null) {
            List<MethodBean> beans = initMethodBeanByParent(superName);
            if (beans != null && !beans.isEmpty()) {
                for (MethodBean bean : beans) {
                    if (!methods.contains(bean)) {
                        methods.add(bean);
                    }
                }
            }
        }
    }

    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        if ("<init>".equals(name)){
            // Constructor
            MethodBean constructor = new MethodBean(access, name, descriptor);
            constructors.add(constructor);
        } else if (!"<clinit>".equals(name)) {
            // Any other method (the static initializer is ignored)
            if ((access & Opcodes.ACC_FINAL) == Opcodes.ACC_FINAL
                    || (access & Opcodes.ACC_STATIC) == Opcodes.ACC_STATIC) {
                return super.visitMethod(access, name, descriptor, signature, exceptions);
            }
            MethodBean methodBean = new MethodBean(access, name, descriptor);
            declaredMethods.add(methodBean);
            if ((access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) {
                methods.add(methodBean);
            }
        }
        return super.visitMethod(access, name, descriptor, signature, exceptions);
    }

    public boolean isFinal() {
        return isFinal;
    }

    public List<MethodBean> getMethods() {
        return methods;
    }

    public List<MethodBean> getDeclaredMethods() {
        return declaredMethods;
    }

    public List<MethodBean> getConstructors() {
        return constructors;
    }

    private List<MethodBean> initMethodBeanByParent(String superName){
        try {
            if (superName != null && !superName.isEmpty()){
                ClassReader reader = new ClassReader(superName);
                TargetClassVisitor visitor = new TargetClassVisitor();
                reader.accept(visitor, ClassReader.SKIP_DEBUG);
                List<MethodBean> beans = new ArrayList<>();
                for (MethodBean methodBean : visitor.methods) {
                    // Skip final and static methods
                    if ((methodBean.access & Opcodes.ACC_FINAL) == Opcodes.ACC_FINAL
                            || (methodBean.access & Opcodes.ACC_STATIC) == Opcodes.ACC_STATIC) {
                        continue;
                    }
                    // Keep public methods only
                    if ((methodBean.access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) {
                        beans.add(methodBean);
                    }
                }
                return beans;
            }
        }catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    public static class MethodBean {

        public int access;
        public String methodName;
        public String methodDesc;

        public MethodBean() {
        }

        public MethodBean(int access, String methodName, String methodDesc) {
            this.access = access;
            this.methodName = methodName;
            this.methodDesc = methodDesc;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == null){
                return false;
            }
            if (!(obj instanceof MethodBean)){
                return false;
            }
            MethodBean bean = (MethodBean) obj;
            if (access == bean.access
                    && methodName != null
                    && bean.methodName != null
                    && methodName.equals(bean.methodName)
                    && methodDesc != null
                    && bean.methodDesc != null
                    && methodDesc.equals(bean.methodDesc)){
                return true;
            }
            return false;
        }
    }
}
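
One design detail worth noting: MethodBean overrides equals but not hashCode. That is enough for the List.contains check used above, but it would misbehave in a HashSet or as a HashMap key. The following is a minimal sketch of a value class with both methods; `MethodKey` is a hypothetical name, and unlike the original equals it treats two null names as equal.

```java
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

public class MethodKey {

    final int access;
    final String name;
    final String desc;

    MethodKey(int access, String name, String desc) {
        this.access = access;
        this.name = name;
        this.desc = desc;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof MethodKey)) {
            return false;
        }
        MethodKey other = (MethodKey) o;
        return access == other.access
                && Objects.equals(name, other.name)
                && Objects.equals(desc, other.desc);
    }

    @Override
    public int hashCode() {
        // Must be consistent with equals: equal keys produce equal hashes
        return Objects.hash(access, name, desc);
    }

    /** Adds two equal keys and one different key to a set and returns its size. */
    static int distinctCount() {
        Set<MethodKey> keys = new HashSet<>();
        keys.add(new MethodKey(1, "toString", "()Ljava/lang/String;"));
        keys.add(new MethodKey(1, "toString", "()Ljava/lang/String;"));
        keys.add(new MethodKey(1, "hashCode", "()I"));
        return keys.size();
    }

    public static void main(String[] args) {
        System.out.println(distinctCount()); // 2
    }
}
```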

Finally comes the proxy class itself. It is fairly complex, so start reading from the entry method newProxyInstance:

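import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// MethodBean below refers to TargetClassVisitor.MethodBean; import it from your own package
public class Proxy {

    public static final int ASM_VERSION = Opcodes.ASM7;
    public static final int ASM_JDK_VERSION = Opcodes.V1_7;

    // Name prefix of the dynamically generated proxy classes
    public static final String PROXY_CLASSNAME_PREFIX = "$Proxy_";
    // Field name
    private static final String FIELD_INVOCATIONHANDLER = "invocationHandler";
    // Method names
    private static final String METHOD_SETTER = "setInvocationHandler";
    private static final String METHOD_INVOKE = "invokeInvocationHandler";
    private static final String METHOD_INVOKE_DESC = "(Ljava/lang/Object;Ljava/lang/reflect/Method;[Ljava/lang/Object;)Ljava/lang/Object;";
    private static final String METHOD_FIELD_PREFIX = "method";
    // Cache that prevents the crash caused by defining the same class twice in one ClassLoader
    private static final Map<String, Class<?>> proxyClassCache = new HashMap<>();

    /**
     * Caches a generated proxy Class; the key is derived from both classLoader and targetClass
     */
    private static void saveProxyClassCache(ClassLoader classLoader, Class targetClass, Class proxyClass) {
        String key = classLoader.toString() + "_" + targetClass.getName();
        proxyClassCache.put(key, proxyClass);
    }

    /**
     * Returns the cached proxy Class, or null if there is none
     */
    private static Class getProxyClassCache(ClassLoader classLoader, Class targetClass) {
        String key = classLoader.toString() + "_" + targetClass.getName();
        return proxyClassCache.get(key);
    }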

    /**
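     * Returns an instance of a dynamically created proxy class that extends targetClass
     *
     * @param classLoader       the ClassLoader used to load the generated Class
     * @param invocationHandler the callback invoked for every method call on the proxy
     * @param targetClass       the class to be proxied
     * @param targetConstructor one constructor of the target class; decides which constructor is used to instantiate the proxy
     * @param targetParam       the arguments for that constructor
     * @return the proxy instance, or null if generation fails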
     */
    public static Object newProxyInstance(ClassLoader classLoader,
                                          InvocationHandler invocationHandler,
                                          Class targetClass,
                                          Constructor targetConstructor,
                                          Object... targetParam) {
        if (classLoader == null || targetClass == null || invocationHandler == null) {
            throw new IllegalArgumentException("argument is null");
        }
        try {
            // Check the cache first
            Class proxyClass = getProxyClassCache(classLoader, targetClass);
            if (proxyClass != null) {
                // Instantiate the proxy object
                return newInstance(proxyClass, invocationHandler, targetConstructor, targetParam);
            }
            // Read some data from the target class
            ClassReader reader = new ClassReader(targetClass.getName());
            TargetClassVisitor targetClassVisitor = new TargetClassVisitor();
            reader.accept(targetClassVisitor, ClassReader.SKIP_DEBUG);
            // Reject final classes
            if (targetClassVisitor.isFinal()) {
                throw new IllegalArgumentException("class is final");
            }
            // Start generating the proxy class
            ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
            String newClassName = generateProxyClassName(targetClass);
            String newClassInnerName = newClassName.replace(".", "/");
            String targetClassName = targetClass.getName();
            String targetClassInnerName = Type.getInternalName(targetClass);
            // Create the class
            newClass(writer, newClassInnerName, targetClassInnerName);
            // Add the InvocationHandler field
            addField(writer);
            // Add the setter for the InvocationHandler
            addSetterMethod(writer, newClassInnerName);
            // Add the constructors, which call super directly
            List<MethodBean> constructors = targetClassVisitor.getConstructors();
            addConstructor(writer, constructors, targetClassInnerName);
            // Add the method that calls the InvocationHandler
            addInvokeMethod(writer, newClassInnerName);
            // Add the inherited public methods and the target class's protected and default methods
            List<MethodBean> methods = targetClassVisitor.getMethods();
            List<MethodBean> declaredMethods = targetClassVisitor.getDeclaredMethods();
            Map<Integer, Integer> methodsMap = new HashMap<>();
            Map<Integer, Integer> declaredMethodsMap = new HashMap<>();
            int methodNameIndex = 0;
            methodNameIndex = addMethod(writer, newClassInnerName, targetClass.getMethods(),
                    methods, true, methodNameIndex, methodsMap);
            addMethod(writer, newClassInnerName, targetClass.getDeclaredMethods(),
                    declaredMethods, false, methodNameIndex, declaredMethodsMap);
            // Add the static initializer block
            addStaticInitBlock(writer, targetClassName, newClassInnerName, methodsMap, declaredMethodsMap);
            // Produce the class bytes
            byte[] bytes = writer.toByteArray();
            // Save to a file, for debugging
//            File outputFile = new File("/Users/jm/Downloads/Demo/" + newClassInnerName + ".class");
//            save2File(outputFile, bytes);
            // Load the Class through the given ClassLoader
            proxyClass = transfer2Class(classLoader, bytes);
            // Cache it
            saveProxyClassCache(classLoader, targetClass, proxyClass);
            // Instantiate the proxy object
            return newInstance(proxyClass, invocationHandler, targetConstructor, targetParam);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Naming rule for generated proxy classes
     */
    private static String generateProxyClassName(Class targetClass) {
        return targetClass.getPackage().getName() + "." + PROXY_CLASSNAME_PREFIX + targetClass.getSimpleName();
    }

    /**
     * Instantiates the proxy class through the constructor that matches the target's constructor, then calls its setter
     */
    private static Object newInstance(Class proxyClass,
                                      InvocationHandler invocationHandler,
                                      Constructor targetConstructor,
                                      Object... targetParam) throws Exception {
        Class[] parameterTypes = targetConstructor.getParameterTypes();
        Constructor constructor = proxyClass.getConstructor(parameterTypes);
        Object instance = constructor.newInstance(targetParam);
        Method setterMethod = proxyClass.getDeclaredMethod(METHOD_SETTER, InvocationHandler.class);
        setterMethod.setAccessible(true);
        setterMethod.invoke(instance, invocationHandler);
        return instance;
    }

    /**
     * Creates the class
     */
    private static void newClass(ClassWriter writer, String newClassName, String targetClassName) throws Exception {
        int access = Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL;
        writer.visit(ASM_JDK_VERSION, access, newClassName, null, targetClassName, null);
    }

    /**
     * Adds the invocationHandler field
     */
    private static void addField(ClassWriter writer) throws Exception {
        FieldVisitor fieldVisitor = writer.visitField(Opcodes.ACC_PRIVATE, FIELD_INVOCATIONHANDLER,
                Type.getDescriptor(InvocationHandler.class), null, null);
        fieldVisitor.visitEnd();
    }

    /**
     * Adds the setter method for invocationHandler
     */
    private static void addSetterMethod(ClassWriter writer, String owner) throws Exception {
        String methodDesc = "(" + Type.getDescriptor(InvocationHandler.class) + ")V";
        MethodVisitor methodVisitor = writer.visitMethod(Opcodes.ACC_PUBLIC, METHOD_SETTER, methodDesc, null, null);
        methodVisitor.visitCode();
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 1);
        methodVisitor.visitFieldInsn(Opcodes.PUTFIELD, owner, FIELD_INVOCATIONHANDLER,
                Type.getDescriptor(InvocationHandler.class));
        methodVisitor.visitInsn(Opcodes.RETURN);
        methodVisitor.visitMaxs(2, 2);
        methodVisitor.visitEnd();
    }

    /**
     * Adds the constructors
     */
    private static void addConstructor(ClassWriter writer, List<MethodBean> constructors,
                                       String targetClassInnerName) throws Exception {
        for (MethodBean constructor : constructors) {
            Type[] argumentTypes = Type.getArgumentTypes(constructor.methodDesc);
            MethodVisitor methodVisitor = writer.visitMethod(Opcodes.ACC_PUBLIC, "<init>",
                    constructor.methodDesc, null, null);
            methodVisitor.visitCode();
            methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);

            // Push every argument from its local-variable slot. long and double take
            // two slots, so the slot index is tracked separately from the argument index
            int slot = 1;
            for (Type argumentType : argumentTypes) {
                // getOpcode adapts ILOAD to LLOAD, FLOAD, DLOAD or ALOAD for this type
                methodVisitor.visitVarInsn(argumentType.getOpcode(Opcodes.ILOAD), slot);
                slot += argumentType.getSize();
            }

            // Call the super() constructor
            methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, targetClassInnerName, "<init>", constructor.methodDesc, false);
            methodVisitor.visitInsn(Opcodes.RETURN);
            methodVisitor.visitMaxs(slot, slot);
            methodVisitor.visitEnd();
        }
    }

    /**
     * Adds the method that calls invocationHandler's invoke method
     */
    private static void addInvokeMethod(ClassWriter writer, String owner) throws Exception {
        MethodVisitor methodVisitor = writer.visitMethod(Opcodes.ACC_PRIVATE | Opcodes.ACC_VARARGS,
                METHOD_INVOKE, METHOD_INVOKE_DESC, null, null);
        methodVisitor.visitCode();
        // Exception handling
        Label label0 = new Label();
        Label label1 = new Label();
        Label label2 = new Label();
        methodVisitor.visitTryCatchBlock(label0, label1, label2, Type.getInternalName(Throwable.class));
        methodVisitor.visitLabel(label0);
        // Load the invocationHandler field onto the stack
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
        methodVisitor.visitFieldInsn(Opcodes.GETFIELD, owner, FIELD_INVOCATIONHANDLER,
                Type.getDescriptor(InvocationHandler.class));
        // Push the three arguments from their local-variable slots
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 1);
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 2);
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 3);
        String handlerName = Type.getInternalName(InvocationHandler.class);
        String handlerMethodName = "invoke";
        String handlerDesc = "(Ljava/lang/Object;Ljava/lang/reflect/Method;[Ljava/lang/Object;)Ljava/lang/Object;";
        // Call invocationHandler.invoke
        methodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, handlerName, handlerMethodName, handlerDesc, true);
        // Normal return
        methodVisitor.visitLabel(label1);
        methodVisitor.visitInsn(Opcodes.ARETURN);
        // Exception handler
        methodVisitor.visitLabel(label2);
        methodVisitor.visitFrame(Opcodes.F_SAME1, 0, null, 1,
                new Object[]{Type.getInternalName(Throwable.class)});
        methodVisitor.visitVarInsn(Opcodes.ASTORE, 4);
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 4);
        methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Throwable.class),
                "printStackTrace", "()V", false);
        methodVisitor.visitInsn(Opcodes.ACONST_NULL);
        methodVisitor.visitInsn(Opcodes.ARETURN);
        methodVisitor.visitMaxs(4, 5);
        methodVisitor.visitEnd();
    }

    /**
     * Adds the inherited methods or the target class's own methods
     */
    private static int addMethod(ClassWriter writer, String newClassInnerName,
                                 Method[] methods, List<MethodBean> methodBeans,
                                 boolean isPublic, int methodNameIndex,
                                 Map<Integer, Integer> map) throws Exception {
        for (int i = 0; i < methodBeans.size(); i++) {
            MethodBean methodBean = methodBeans.get(i);
            // Skip final and static methods
            if ((methodBean.access & Opcodes.ACC_FINAL) == Opcodes.ACC_FINAL
                    || (methodBean.access & Opcodes.ACC_STATIC) == Opcodes.ACC_STATIC) {
                continue;
            }
            // Keep only methods with the requested modifiers
            int access = -1;
            if (isPublic) {
                // public methods
                if ((methodBean.access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) {
                    access = Opcodes.ACC_PUBLIC;
                }
            } else {
                // protected methods, then default (package-private) methods
                if ((methodBean.access & Opcodes.ACC_PROTECTED) == Opcodes.ACC_PROTECTED) {
                    access = Opcodes.ACC_PROTECTED;
                } else if ((methodBean.access & Opcodes.ACC_PUBLIC) == 0
                        && (methodBean.access & Opcodes.ACC_PROTECTED) == 0
                        && (methodBean.access & Opcodes.ACC_PRIVATE) == 0) {
                    access = 0;
                }
            }
            if (access == -1) {
                continue;
            }
            // Find the matching reflective Method
            int methodIndex = findSomeMethod(methods, methodBean);
            if (methodIndex == -1) {
                continue;
            }
            // Map the suffix index of the new field to the real index in the Method array,
            // for use when the static initializer block is generated later
            map.put(methodNameIndex, methodIndex);
            // Add the static field that holds the Method
            String fieldName = METHOD_FIELD_PREFIX + methodNameIndex;
            FieldVisitor fieldVisitor = writer.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC,
                    fieldName, Type.getDescriptor(Method.class), null, null);
            fieldVisitor.visitEnd();
            // Add the method body that forwards the call
            addMethod(writer, newClassInnerName, methodBean, access, methodNameIndex);
            methodNameIndex++;
        }
        return methodNameIndex;
    }


    /**
     * Generates the body of one proxied method
     */
    private static void addMethod(ClassWriter writer, String newClassInnerName,
                                  MethodBean methodBean, int access, int methodNameIndex) throws Exception {
        MethodVisitor methodVisitor = writer.visitMethod(access, methodBean.methodName,
                methodBean.methodDesc, null, null);
        methodVisitor.visitCode();
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
        // Distinguish static from non-static method calls
        if ((methodBean.access & Opcodes.ACC_STATIC) == Opcodes.ACC_STATIC) {
            methodVisitor.visitInsn(Opcodes.ACONST_NULL);
        } else {
            methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
        }
        // Load the newly created Method field
        methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, newClassInnerName,
                METHOD_FIELD_PREFIX + methodNameIndex, Type.getDescriptor(Method.class));
        Type[] argumentTypes = Type.getArgumentTypes(methodBean.methodDesc);
        // Create the array, sized to the number of method parameters
        methodVisitor.visitIntInsn(Opcodes.BIPUSH, argumentTypes.length);
        methodVisitor.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(Object.class));
        // Track the local-variable slot: double and long take two slots, everything else one
        int start = 1;
        int stop = start;
        // Push the local variables onto the stack; primitive types must be boxed
        for (int i = 0; i < argumentTypes.length; i++) {
            Type type = argumentTypes[i];
            if (type.equals(Type.BYTE_TYPE)) {
                stop = start + 1;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.ILOAD, start);
                methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Byte.class),
                        "valueOf", "(B)Ljava/lang/Byte;", false);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            } else if (type.equals(Type.SHORT_TYPE)) {
                stop = start + 1;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.ILOAD, start);
                methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Short.class),
                        "valueOf", "(S)Ljava/lang/Short;", false);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            } else if (type.equals(Type.CHAR_TYPE)) {
                stop = start + 1;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.ILOAD, start);
                methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Character.class),
                        "valueOf", "(C)Ljava/lang/Character;", false);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            } else if (type.equals(Type.INT_TYPE)) {
                stop = start + 1;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.ILOAD, start);
                methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Integer.class),
                        "valueOf", "(I)Ljava/lang/Integer;", false);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            } else if (type.equals(Type.FLOAT_TYPE)) {
                stop = start + 1;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.FLOAD, start);
                methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Float.class),
                        "valueOf", "(F)Ljava/lang/Float;", false);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            } else if (type.equals(Type.DOUBLE_TYPE)) {
                stop = start + 2;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.DLOAD, start);
                methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Double.class),
                        "valueOf", "(D)Ljava/lang/Double;", false);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            } else if (type.equals(Type.LONG_TYPE)) {
                stop = start + 2;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.LLOAD, start);
                methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Long.class),
                        "valueOf", "(J)Ljava/lang/Long;", false);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            } else if (type.equals(Type.BOOLEAN_TYPE)) {
                stop = start + 1;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.ILOAD, start);
                methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Boolean.class),
                        "valueOf", "(Z)Ljava/lang/Boolean;", false);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            } else {
                stop = start + 1;
                methodVisitor.visitInsn(Opcodes.DUP);
                // Index in the array
                methodVisitor.visitIntInsn(Opcodes.BIPUSH, i);
                // Index in the local variable table
                methodVisitor.visitVarInsn(Opcodes.ALOAD, start);
                methodVisitor.visitInsn(Opcodes.AASTORE);
            }
            start = stop;
        }
        // Call the invokeInvocationHandler method
        methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, newClassInnerName,
                METHOD_INVOKE, METHOD_INVOKE_DESC, false);
        // Handle the return value; primitive types must be unboxed
        Type returnType = Type.getReturnType(methodBean.methodDesc);
        if (returnType.equals(Type.BYTE_TYPE)) {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Byte.class));
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Byte.class),
                    "byteValue", "()B", false);
            methodVisitor.visitInsn(Opcodes.IRETURN);
        } else if (returnType.equals(Type.BOOLEAN_TYPE)) {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Boolean.class));
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Boolean.class),
                    "booleanValue", "()Z", false);
            methodVisitor.visitInsn(Opcodes.IRETURN);
        } else if (returnType.equals(Type.CHAR_TYPE)) {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Character.class));
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Character.class),
                    "charValue", "()C", false);
            methodVisitor.visitInsn(Opcodes.IRETURN);
        } else if (returnType.equals(Type.SHORT_TYPE)) {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Short.class));
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Short.class),
                    "shortValue", "()S", false);
            methodVisitor.visitInsn(Opcodes.IRETURN);
        } else if (returnType.equals(Type.INT_TYPE)) {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Integer.class));
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Integer.class),
                    "intValue", "()I", false);
            methodVisitor.visitInsn(Opcodes.IRETURN);
        } else if (returnType.equals(Type.LONG_TYPE)) {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Long.class));
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Long.class),
                    "longValue", "()J", false);
            methodVisitor.visitInsn(Opcodes.LRETURN);
        } else if (returnType.equals(Type.FLOAT_TYPE)) {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Float.class));
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Float.class),
                    "floatValue", "()F", false);
            methodVisitor.visitInsn(Opcodes.FRETURN);
        } else if (returnType.equals(Type.DOUBLE_TYPE)) {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Double.class));
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Double.class),
                    "doubleValue", "()D", false);
            methodVisitor.visitInsn(Opcodes.DRETURN);
        } else if (returnType.equals(Type.VOID_TYPE)) {
            methodVisitor.visitInsn(Opcodes.RETURN);
        } else {
            methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, returnType.getInternalName());
            methodVisitor.visitInsn(Opcodes.ARETURN);
        }
        methodVisitor.visitMaxs(8, 37);
        methodVisitor.visitEnd();
    }

    /**
     * Adds the static initializer block
     */
    private static void addStaticInitBlock(ClassWriter writer, String targetClassName,
                                           String newClassInnerName, Map<Integer, Integer> methodsMap,
                                           Map<Integer, Integer> declaredMethodsMap) throws Exception {
        String exceptionClassName = Type.getInternalName(ClassNotFoundException.class);
        MethodVisitor methodVisitor = writer.visitMethod(Opcodes.ACC_STATIC, "<clinit>",
                "()V", null, null);
        methodVisitor.visitCode();
        // Start of the exception handling
        Label label0 = new Label();
        Label label1 = new Label();
        Label label2 = new Label();
        methodVisitor.visitTryCatchBlock(label0, label1, label2, exceptionClassName);
        methodVisitor.visitLabel(label0);
        // Initialize the fields that correspond to inherited methods
        for (Map.Entry<Integer, Integer> entry : methodsMap.entrySet()) {
            Integer key = entry.getKey();
            Integer value = entry.getValue();
            methodVisitor.visitLdcInsn(targetClassName);
            methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Class.class),
                    "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false);
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
                    "getMethods", "()[Ljava/lang/reflect/Method;", false);
            methodVisitor.visitIntInsn(Opcodes.BIPUSH, value);
            methodVisitor.visitInsn(Opcodes.AALOAD);
            methodVisitor.visitFieldInsn(Opcodes.PUTSTATIC, newClassInnerName,
                    METHOD_FIELD_PREFIX + key, Type.getDescriptor(Method.class));
        }
        // 给目标类本身的方法添加对应的字段初始化
        for (Map.Entry entry : declaredMethodsMap.entrySet()) {
            Integer key = entry.getKey();
            Integer value = entry.getValue();
            methodVisitor.visitLdcInsn(targetClassName);
            methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Class.class),
                    "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false);
            methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
                    "getDeclaredMethods", "()[Ljava/lang/reflect/Method;", false);
            methodVisitor.visitIntInsn(Opcodes.BIPUSH, value);
            methodVisitor.visitInsn(Opcodes.AALOAD);
            methodVisitor.visitFieldInsn(Opcodes.PUTSTATIC, newClassInnerName,
                    METHOD_FIELD_PREFIX + key, Type.getDescriptor(Method.class));
        }

        methodVisitor.visitLabel(label1);
        Label label3 = new Label();
        methodVisitor.visitJumpInsn(Opcodes.GOTO, label3);
        methodVisitor.visitLabel(label2);
        methodVisitor.visitFrame(Opcodes.F_SAME1, 0, null, 1,
                new Object[]{exceptionClassName});
        methodVisitor.visitVarInsn(Opcodes.ASTORE, 0);
        methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
        methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, exceptionClassName,
                "printStackTrace", "()V", false);
        methodVisitor.visitLabel(label3);
        methodVisitor.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
        methodVisitor.visitInsn(Opcodes.RETURN);
        methodVisitor.visitMaxs(2, 1);
        methodVisitor.visitEnd();
    }

    /**
     * Finds the index of the method that matches the given MethodBean,
     * or returns -1 if there is none.
     */
    private static int findSomeMethod(Method[] methods, MethodBean methodBean) {
        for (int i = 0; i < methods.length; i++) {
            if (equalsMethod(methods[i], methodBean)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Checks whether a {@link Method} and a {@link MethodBean} describe the same method
     * (same name, return type, and argument types).
     */
    private static boolean equalsMethod(Method method, MethodBean methodBean) {
        if (method == null && methodBean == null) {
            return true;
        }
        if (method == null || methodBean == null) {
            return false;
        }
        try {
            if (!method.getName().equals(methodBean.methodName)) {
                return false;
            }
            if (!Type.getReturnType(method).equals(Type.getReturnType(methodBean.methodDesc))) {
                return false;
            }
            Type[] argumentTypes1 = Type.getArgumentTypes(method);
            Type[] argumentTypes2 = Type.getArgumentTypes(methodBean.methodDesc);
            if (argumentTypes1.length != argumentTypes2.length) {
                return false;
            }
            for (int i = 0; i < argumentTypes1.length; i++) {
                if (!argumentTypes1[i].equals(argumentTypes2[i])) {
                    return false;
                }
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return false;
    }

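    private static void save2File(File file, byte[] bytes) {
        File parent = file.getParentFile();
        // getParentFile() returns null for a bare file name, so guard against it
        if (parent != null && !parent.exists()) {
            parent.mkdirs();
        }
        // try-with-resources closes the stream even if write() throws
        try (OutputStream out = new FileOutputStream(file)) {
            out.write(bytes);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }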

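    /**
     * Converts a byte array into a Class by reflectively calling the protected
     * ClassLoader#defineClass. On JDK 16+ this reflective access is blocked by default
     * unless the JVM is started with --add-opens java.base/java.lang=ALL-UNNAMED.
     */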
    private static Class transfer2Class(ClassLoader classLoader, byte[] bytes) {
        try {
            Class cl = Class.forName("java.lang.ClassLoader");
            Method defineClassMethod = cl.getDeclaredMethod("defineClass",
                    new Class[]{String.class, byte[].class, int.class, int.class});
            defineClassMethod.setAccessible(true);
            Class clazz = (Class) defineClassMethod.invoke(classLoader,
                    new Object[]{null, bytes, 0, bytes.length});
            return clazz;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }
}
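
To make the relationship between findSomeMethod and the generated static block easier to follow, here is a minimal plain-Java sketch of the same idea. It is not the article's code: the class name MethodIndexDemo and the helpers descriptorOf and indexOf are hypothetical, and it uses the JDK's java.lang.invoke.MethodType to build descriptors instead of ASM's Type class. At generation time the index of a method is looked up by name and descriptor; at class-initialization time the generated code fetches the Method again by that index.

```java
import java.lang.invoke.MethodType;
import java.lang.reflect.Method;

public class MethodIndexDemo {

    /** JVM descriptor of a reflected method, e.g. "(I)C" for String.charAt. */
    public static String descriptorOf(Method method) {
        return MethodType.methodType(method.getReturnType(), method.getParameterTypes())
                .toMethodDescriptorString();
    }

    /** Generation-time step: index of the method with this name and descriptor, or -1. */
    public static int indexOf(Method[] methods, String name, String descriptor) {
        for (int i = 0; i < methods.length; i++) {
            if (methods[i].getName().equals(name)
                    && descriptorOf(methods[i]).equals(descriptor)) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) throws Exception {
        Method[] methods = String.class.getMethods();
        int index = indexOf(methods, "charAt", "(I)C");
        if (index < 0) {
            System.out.println("not found");
            return;
        }
        // Class-initialization-time step: roughly what each generated
        // "forName / getMethods / AALOAD / PUTSTATIC" sequence does.
        Method cached = Class.forName("java.lang.String").getMethods()[index];
        System.out.println(index + " -> " + cached);
    }
}
```

One caveat worth keeping in mind: the Javadoc for Class#getMethods and Class#getDeclaredMethods states that the returned array is not in any particular order, so looking a method up by index assumes both calls return the same order within one JVM run. Storing the name and parameter types and calling getDeclaredMethod would avoid that assumption.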

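3. Usage

Usage is similar to the JDK's dynamic proxy. The only extra step is specifying which constructor to use when instantiating the proxy, because a concrete class may have several constructors.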

private static void proxyByASM() {
    try {
        Demo demo = new Demo();
        Class clazz = Demo.class;
        // Specify the proxied class's constructor; it is mapped internally to the proxy class's constructor
        Constructor constructor = clazz.getConstructor(new Class[]{});
        Object[] constructorParam = new Object[]{};
        // Specify the callback that handles method invocations
        InvocationHandler invocationHandler = new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                System.out.println("before:" + method.getName());
                // Note: this invokes the method on the proxied object, so pass demo rather than proxy
                method.setAccessible(true);
                Object result = method.invoke(demo, args);
                System.out.println("after:" + method.getName());
                return result;
            }
        };
        Object proxy = Proxy.newProxyInstance(clazz.getClassLoader(), invocationHandler, clazz, constructor, constructorParam);
        // Test public, protected, and default (package-private) methods in turn
        ((Demo) proxy).publicDemo();
        ((Demo) proxy).protectedDemo();
        ((Demo) proxy).defaultDemo();
        // Test a method with a return value
        ((Demo) proxy).haha();
        // Test an inherited method
        ((Demo) proxy).superPublic();
        System.out.println(((Demo) proxy).toString());
    } catch (Exception e) {
        e.printStackTrace();
    }
}
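
For comparison, the same before/after pattern with the JDK's own java.lang.reflect.Proxy might look like the sketch below. Note that it uses the JDK's Proxy and InvocationHandler, not the custom classes from this article, and that Greeter, SimpleGreeter, and wrap are hypothetical names made up for illustration. The JDK version needs an interface and no constructor arguments, which is the main difference from the ASM-based call above.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

public class JdkProxyComparison {

    /** Hypothetical interface; the JDK proxy can only proxy interfaces. */
    public interface Greeter {
        String greet(String name);
    }

    /** Hypothetical target implementation. */
    public static class SimpleGreeter implements Greeter {
        @Override
        public String greet(String name) {
            return "hello " + name;
        }
    }

    /** Wraps the target in a JDK proxy that records before/after events in the log. */
    public static Greeter wrap(Greeter target, List<String> log) {
        InvocationHandler handler = (proxy, method, args) -> {
            log.add("before:" + method.getName());
            // Delegate to the real target, not to the proxy, to avoid infinite recursion
            Object result = method.invoke(target, args);
            log.add("after:" + method.getName());
            return result;
        };
        return (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(), new Class<?>[]{Greeter.class}, handler);
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        Greeter greeter = wrap(new SimpleGreeter(), log);
        System.out.println(greeter.greet("asm")); // prints "hello asm"
        System.out.println(log);                  // prints "[before:greet, after:greet]"
    }
}
```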
