番外篇,jdk自带动态代理源码分析

之前和同学聊起代理模式,顺嘴提到了动态代理,就顺便看了一下源码,话不多说,开始分析,和之前一样为了方便理解,我会直接在代码中注释

这是一段很常见的动态代理代码,TestInterface是一个接口,里面只有一个test方法,TestInterfaceImpl类实现了TestInterface接口,代码也比较简单,我就不全部贴出来了

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class Test implements InvocationHandler {

    public  TestInterface ti;

    public Object newProxyInstance(TestInterface ti) {
        this.ti = ti;
        //重点
        return Proxy.newProxyInstance(ti.getClass().getClassLoader(), ti.getClass().getInterfaces(), this);
    }
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        System.out.println("执行方法前的操作");
        if(method.getName().equals("test"))
        {
            ti.test();
        }
        System.out.println("执行方法后的操作");
        return null;
    }

    public static void main(String[] args) {
        Test test =new Test ();
        TestInterface ti = (TestInterface)test.newProxyInstance(new TestInterfaceImpl());
        ti.test();
    }
}

重点是Proxy.newProxyInstance方法

public static Object newProxyInstance(ClassLoader loader,
                                          Class<?>[] interfaces,
                                          InvocationHandler h) {
        //若传入的动态代理类为空抛出异常
        Objects.requireNonNull(h);
		//获取当前系统的安全检查器
		//因为这里会进行类加载和io操作,需要进行安全检查
        final Class<?> caller = System.getSecurityManager() == null
                                    ? null
                                    : Reflection.getCallerClass();

        /*
         * Look up or generate the designated proxy class and its constructor.
         */
         //关键方法
        Constructor<?> cons = getProxyConstructor(caller, loader, interfaces);
		//此处只是调用Constructor的newInstance方法
		//把动态代理对象传入,调用对应的构造方法
		//返回一个动态代理后的实体类
		//意义不大,这里就不多介绍
        return newProxyInstance(caller, cons, h);
    }

进入关键方法getProxyConstructor

private static Constructor<?> getProxyConstructor(Class<?> caller,
                                                      ClassLoader loader,
                                                      Class<?>... interfaces)
    {
        //判断需要代理的接口是否为多个,进行简单转换
        //if else内容几乎一致,就只对前面的if内容进行分析,不再赘述
        if (interfaces.length == 1) {
            Class<?> intf = interfaces[0];
            //判断是否需要进行安全检查
            if (caller != null) {
                checkProxyAccess(caller, loader, intf);
            }
            //重点方法,我会进行单独讲解
            return proxyCache.sub(intf).computeIfAbsent(
                loader,
                (ld, clv) -> new ProxyBuilder(ld, clv.key()).build()
            );
        } else {
            // interfaces cloned
            final Class<?>[] intfsArray = interfaces.clone();
            if (caller != null) {
                checkProxyAccess(caller, loader, intfsArray);
            }
            final List<Class<?>> intfs = Arrays.asList(intfsArray);
            return proxyCache.sub(intfs).computeIfAbsent(
                loader,
                (ld, clv) -> new ProxyBuilder(ld, clv.key()).build()
            );
        }
    }

这段代码相当优雅,把lanbda表达式写的出神入化(但是我不建议大家平时这样写,可读性极差,我自己看这段代码蒙了有一会)

return proxyCache.sub(intf).computeIfAbsent(
                loader,
                (ld, clv) -> new ProxyBuilder(ld, clv.key()).build()
            );

首先我们来看computeIfAbsent方法

public V computeIfAbsent(ClassLoader cl,
                             BiFunction<
                                 ? super ClassLoader,
                                 ? super CLV,
                                 ? extends V
                                 > mappingFunction)

lanbda表达式推导的就是BiFunction接口的方法,进入BiFunction接口

public interface BiFunction<T, U, R> {
    R apply(T t, U u);
    。。。省略。。。
}

问题来了,我们可以把lanbda表达式写成下面的代码

R apply(T ld, U clv){
	return new ProxyBuilder(ld, clv.key()).build();
}

然后通过类型推导(根据computeIfAbsent方法进行推导)将代码推导为

R apply(ClassLoader ld, U clv){
	return new ProxyBuilder(ld, clv.key()).build();
}

再根据sub方法,和AbstractClassLoaderValue,Sub类进行进一步的推导

public abstract class AbstractClassLoaderValue<CLV extends AbstractClassLoaderValue<CLV, V>, V>
public final class Sub<K> extends AbstractClassLoaderValue<Sub<K>, V>
public <K> Sub<K> sub(K key) {
        return new Sub<K>(key);
    }

最终才能得到推导后的方法

R apply(ClassLoader ld, Sub clv){
	return new ProxyBuilder(ld, clv.key()).build();
}

但是,此时返回值依然是不确定的,根据build的返回值推导,为Constructor,最终的方法才能得到

Constructor apply(ClassLoader ld, Sub clv){
	return new ProxyBuilder(ld, clv.key()).build();
}

确实挺优雅的,就是太难理解,进入computeIfAbsent方法进行分析,设计的有点绕,猜测是为了解决复用性

public V computeIfAbsent(ClassLoader cl,
                             BiFunction<
                                 ? super ClassLoader,
                                 ? super CLV,
                                 ? extends V
                                 > mappingFunction) throws IllegalStateException {
        //创建map集合,注意,此处会根据类加载器加载对应的map集合
        ConcurrentHashMap<CLV, Object> map = map(cl);
        @SuppressWarnings("unchecked")
        //获取自身对象此处为sub对象
        CLV clv = (CLV) this;
        Memoizer<CLV, V> mv = null;
        while (true) {
        	//查看mv对象是否为空,如果为空(即,首次进入循环),尝试从集合中获取vsl
        	//如果可以获取,认为该类已经加载,直接返回类对象
        	//否则进行添加并赋值给val
            Object val = (mv == null) ? map.get(clv) : map.putIfAbsent(clv, mv);
            //当第一次运行时,val和mv均为null
            if (val == null) {
                if (mv == null) {
                    // create Memoizer lazily when 1st needed and restart loop
                    //创建Memoizer对象
                    //只做一件事,将三个参数赋值给自身的成员变量
                    mv = new Memoizer<>(cl, clv, mappingFunction);
                    continue;
                }
                // mv != null, therefore sv == null was a result of successful
                // putIfAbsent
                try {
                    // trigger Memoizer to compute the value
                    //主要方法
                    V v = mv.get();
                    // attempt to replace our Memoizer with the value
                    //替换元素
                    map.replace(clv, mv, v);
                    // return computed value
                    return v;
                } catch (Throwable t) {
                    // our Memoizer has thrown, attempt to remove it
                    map.remove(clv, mv);
                    // propagate exception because it's from our Memoizer
                    throw t;
                }
            } else {
                try {
                	//如果进行到这一步,说明map集合中存在已经加载好的类
                	//判断是否为Memoizer(若为Memoizer可能是多线程条件下get方法尚未完成)
                	//然后返回对应的类,或是继续调用get方法
                    return extractValue(val);
                } catch (Memoizer.RecursiveInvocationException e) {
                    // propagate recursive attempts to calculate the same
                    // value as being calculated at the moment
                    throw e;
                } catch (Throwable t) {
                    // don't propagate exceptions thrown from foreign Memoizer -
                    // pretend that there was no entry and retry
                    // (foreign computeIfAbsent invocation will try to remove it anyway)
                }
            }
            // TODO:
            // Thread.onSpinLoop(); // when available
        }
    }

get方法

public V get() throws RecursiveInvocationException {
            V v = this.v;
            if (v != null) return v;
            Throwable t = this.t;
            if (t == null) {
                synchronized (this) {
                    if ((v = this.v) == null && (t = this.t) == null) {
                        if (inCall) {
                            throw new RecursiveInvocationException();
                        }
                        inCall = true;
                        try {
                            this.v = v = Objects.requireNonNull(
                            	//前面后面都是健壮性和复用性检测
                            	//没什么好看的关键是是这里
                            	//调用前面用lanbda表达式传入发方法
                                mappingFunction.apply(cl, clv));
                        } catch (Throwable x) {
                            this.t = t = x;
                        } finally {
                            inCall = false;
                        }
                    }
                }
            }
            if (v != null) return v;
            if (t instanceof Error) {
                throw (Error) t;
            } else if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            } else {
                throw new UndeclaredThrowableException(t);
            }
        }

回到之前的build方法

Constructor<?> build() {
			//核心,返回class对象
            Class<?> proxyClass = defineProxyClass(module, interfaces);
            final Constructor<?> cons;
            try {
            	//返回对应的构造方法,没什么好看的
                cons = proxyClass.getConstructor(constructorParams);
            } catch (NoSuchMethodException e) {
                throw new InternalError(e.toString(), e);
            }
            AccessController.doPrivileged(new PrivilegedAction<Void>() {
                public Void run() {
                    cons.setAccessible(true);
                    return null;
                }
            });
            return cons;
        }

defineProxyClass方法

private static Class<?> defineProxyClass(Module m, List<Class<?>> interfaces) {
            String proxyPkg = null;     // package to define proxy class in
            int accessFlags = Modifier.PUBLIC | Modifier.FINAL;

            /*
             * Record the package of a non-public proxy interface so that the
             * proxy class will be defined in the same package.  Verify that
             * all non-public proxy interfaces are in the same package.
             */
            for (Class<?> intf : interfaces) {
                int flags = intf.getModifiers();
                if (!Modifier.isPublic(flags)) {
                    accessFlags = Modifier.FINAL;  // non-public, final
                    String pkg = intf.getPackageName();
                    if (proxyPkg == null) {
                        proxyPkg = pkg;
                    } else if (!pkg.equals(proxyPkg)) {
                        throw new IllegalArgumentException(
                                "non-public interfaces from different packages");
                    }
                }
            }

            if (proxyPkg == null) {
                // all proxy interfaces are public
                proxyPkg = m.isNamed() ? PROXY_PACKAGE_PREFIX + "." + m.getName()
                                       : PROXY_PACKAGE_PREFIX;
            } else if (proxyPkg.isEmpty() && m.isNamed()) {
                throw new IllegalArgumentException(
                        "Unnamed package cannot be added to " + m);
            }

            if (m.isNamed()) {
                if (!m.getDescriptor().packages().contains(proxyPkg)) {
                    throw new InternalError(proxyPkg + " not exist in " + m.getName());
                }
            }

            /*
             * Choose a name for the proxy class to generate.
             */
            long num = nextUniqueNumber.getAndIncrement();
            String proxyName = proxyPkg.isEmpty()
                                    ? proxyClassNamePrefix + num
                                    : proxyPkg + "." + proxyClassNamePrefix + num;

            ClassLoader loader = getLoader(m);
            trace(proxyName, m, loader, interfaces);

            /*
             * Generate the specified proxy class.
             */
            byte[] proxyClassFile = ProxyGenerator.generateProxyClass(
                    proxyName, interfaces.toArray(EMPTY_CLASS_ARRAY), accessFlags);
            try {
                Class<?> pc = UNSAFE.defineClass(proxyName, proxyClassFile,
                                                 0, proxyClassFile.length,
                                                 loader, null);
                reverseProxyCache.sub(pc).putIfAbsent(loader, Boolean.TRUE);
                return pc;
            } catch (ClassFormatError e) {
                /*
                 * A ClassFormatError here means that (barring bugs in the
                 * proxy class generation code) there was some other
                 * invalid aspect of the arguments supplied to the proxy
                 * class creation (such as virtual machine limitations
                 * exceeded).
                 */
                throw new IllegalArgumentException(e.toString());
            }
        }

其他都是做一些检验,没什么好说的,其实这么一大段重要的就下面

//生成动态代理字节码
byte[] proxyClassFile = ProxyGenerator.generateProxyClass(
                    proxyName, interfaces.toArray(EMPTY_CLASS_ARRAY), accessFlags);
            try {
            	//通过类加载器加载字节码
                Class<?> pc = UNSAFE.defineClass(proxyName, proxyClassFile,
                                                 0, proxyClassFile.length,
                                                 loader, null);
                reverseProxyCache.sub(pc).putIfAbsent(loader, Boolean.TRUE);
                return pc;

最终生成字节码也就是在内存中生成.class文件的方法,大家看看就行了,太多了,而且也没啥技术含量,方法和需要实现的接口方法内部字节码都是基本固定的,无非是把实现的接口方法转发到一个统一的方法中,再通过统一方法调用我们自己实现的invoke方法,就是实现的接口,导包的路径,类名称不同,稍有jvm常识就能自己写一个,就是费时间而已

private byte[] generateClassFile() {

        /* ============================================================
         * Step 1: Assemble ProxyMethod objects for all methods to
         * generate proxy dispatching code for.
         */

        /*
         * Record that proxy methods are needed for the hashCode, equals,
         * and toString methods of java.lang.Object.  This is done before
         * the methods from the proxy interfaces so that the methods from
         * java.lang.Object take precedence over duplicate methods in the
         * proxy interfaces.
         */
        //添加hashCode,equals,toString方法
        addProxyMethod(hashCodeMethod, Object.class);
        addProxyMethod(equalsMethod, Object.class);
        addProxyMethod(toStringMethod, Object.class);

        /*
         * Now record all of the methods from the proxy interfaces, giving
         * earlier interfaces precedence over later ones with duplicate
         * methods.
         */
        for (Class<?> intf : interfaces) {
            for (Method m : intf.getMethods()) {
                addProxyMethod(m, intf);
            }
        }

        /*
         * For each set of proxy methods with the same signature,
         * verify that the methods' return types are compatible.
         */
        for (List<ProxyMethod> sigmethods : proxyMethods.values()) {
            checkReturnTypes(sigmethods);
        }

        /* ============================================================
         * Step 2: Assemble FieldInfo and MethodInfo structs for all of
         * fields and methods in the class we are generating.
         */
        try {
            methods.add(generateConstructor());
			//开始生成代理方法
            for (List<ProxyMethod> sigmethods : proxyMethods.values()) {
                for (ProxyMethod pm : sigmethods) {

                    // add static field for method's Method object
                    //生成对应的方法访问标志
                    fields.add(new FieldInfo(pm.methodFieldName,
                        "Ljava/lang/reflect/Method;",
                         ACC_PRIVATE | ACC_STATIC));

                    // generate code for proxy method and add it
                    //生成代理方法
                    methods.add(pm.generateMethod());
                }
            }
			//添加静态块
            methods.add(generateStaticInitializer());

        } catch (IOException e) {
            throw new InternalError("unexpected I/O Exception", e);
        }
		
        if (methods.size() > 65535) {
            throw new IllegalArgumentException("method limit exceeded");
        }
        if (fields.size() > 65535) {
            throw new IllegalArgumentException("field limit exceeded");
        }

        /* ============================================================
         * Step 3: Write the final class file.
         */

        /*
         * Make sure that constant pool indexes are reserved for the
         * following items before starting to write the final class file.
         */
        //获取类名(稍后还要进行拼接)和父类名
        cp.getClass(dotToSlash(className));
        cp.getClass(superclassName);
        //获取所有接口名
        for (Class<?> intf: interfaces) {
            cp.getClass(dotToSlash(intf.getName()));
        }

        /*
         * Disallow new constant pool additions beyond this point, since
         * we are about to write the final constant pool table.
         */
        cp.setReadOnly();

        ByteArrayOutputStream bout = new ByteArrayOutputStream();
        DataOutputStream dout = new DataOutputStream(bout);

        try {
            /*
             * Write all the items of the "ClassFile" structure.
             * See JVMS section 4.1.
             */
                                        // u4 magic;
			//添加魔数
            dout.writeInt(0xCAFEBABE);
                                        // u2 minor_version;
          	//添加版本信息
            dout.writeShort(CLASSFILE_MINOR_VERSION);
                                        // u2 major_version;
       		//添加小版本信息
            dout.writeShort(CLASSFILE_MAJOR_VERSION);
			//添加常量池
            cp.write(dout);             // (write constant pool)

                                        // u2 access_flags;
            //添加标志位
            dout.writeShort(accessFlags);
                                        // u2 this_class;
            //添加类名
            dout.writeShort(cp.getClass(dotToSlash(className)));
                                        // u2 super_class;
            //添加父类名
            dout.writeShort(cp.getClass(superclassName));

                                        // u2 interfaces_count;
            //添加接口表标志位
            dout.writeShort(interfaces.length);
                                        // u2 interfaces[interfaces_count];
            //添加接口
            for (Class<?> intf : interfaces) {
                dout.writeShort(cp.getClass(
                    dotToSlash(intf.getName())));
            }

                                        // u2 fields_count;
            //添加字段表标志位
            dout.writeShort(fields.size());
                                        // field_info fields[fields_count];
          	//添加字段
            for (FieldInfo f : fields) {
                f.write(dout);
            }

                                        // u2 methods_count;
            //添加方法表标志位
            dout.writeShort(methods.size());
                                        // method_info methods[methods_count];
            //添加方法
            for (MethodInfo m : methods) {
                m.write(dout);
            }

                                         // u2 attributes_count;
            //类文件属性,稍后会在另外一个方法设置
            dout.writeShort(0); // (no ClassFile attributes for proxy classes)

        } catch (IOException e) {
            throw new InternalError("unexpected I/O Exception", e);
        }

        return bout.toByteArray();
    }

通过反射获取生成的class字节码

import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.reflect.Method;

public class Test {
    private static void saveProxyFile() {
        FileOutputStream out = null;
        try {
            Class clazz = Class.forName("java.lang.reflect.ProxyGenerator");
            Method m = clazz.getDeclaredMethod("generateProxyClass", String.class, Class[].class);
            m.setAccessible(true);
            byte[] bs = (byte[]) m.invoke(null, "$Proxy0", new Class[]{TestInterface.class});
            out = new FileOutputStream("$Proxy0.class");
            out.write(bs);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            try {
                if (out != null) {
                    out.flush();
                    out.close();
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) {
        saveProxyFile();
    }
}

生成的class字节码

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.lang.reflect.UndeclaredThrowableException;

public final class $Proxy0 extends Proxy implements TestInterface {
	//方法对象
    private static Method m1;
    private static Method m3;
    private static Method m2;
    private static Method m0;
	//使用父类的构造方法处理
    public $Proxy0(InvocationHandler var1) throws  {
        super(var1);
    }
	//将所有方法转发到父类h属性的invoke方法
    public final boolean equals(Object var1) throws  {
        try {
            return (Boolean)super.h.invoke(this, m1, new Object[]{var1});
        } catch (RuntimeException | Error var3) {
            throw var3;
        } catch (Throwable var4) {
            throw new UndeclaredThrowableException(var4);
        }
    }

    public final void test() throws  {
        try {
            super.h.invoke(this, m3, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final String toString() throws  {
        try {
            return (String)super.h.invoke(this, m2, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final int hashCode() throws  {
        try {
            return (Integer)super.h.invoke(this, m0, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }
	//将方法对象和实际方法关联
    static {
        try {
            m1 = Class.forName("java.lang.Object").getMethod("equals", Class.forName("java.lang.Object"));
            m3 = Class.forName("TestInterface").getMethod("test");
            m2 = Class.forName("java.lang.Object").getMethod("toString");
            m0 = Class.forName("java.lang.Object").getMethod("hashCode");
        } catch (NoSuchMethodException var2) {
            throw new NoSuchMethodError(var2.getMessage());
        } catch (ClassNotFoundException var3) {
            throw new NoClassDefFoundError(var3.getMessage());
        }
    }
}

父类的构造方法,最终所有方法都指向我们传入的类的invoke方法

protected Proxy(InvocationHandler h) {
        Objects.requireNonNull(h);
        this.h = h;
    }

获取class对象的方法,因为是native 修饰,我就不继续下去了

public native Class<?> defineClass0(String name, byte[] b, int off, int len,
                                        ClassLoader loader,
                                        ProtectionDomain protectionDomain);

至此,分析结束

你可能感兴趣的:(番外篇,jdk自带动态代理源码分析)