Java实现RPC(服务对象使用注解并自动注入)

  • 使用到的技术:
  1. 注解和反射机制
  2. 包扫描以及jar包扫描
  3. CGlib动态代理
  4. 类似于spring框架的控制反转依赖自动注入技术
  • 目录结构:
    Java实现RPC(服务对象使用注解并自动注入)_第1张图片
  • RPCclass注解
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.TYPE)
    public @interface RPCclass {
        boolean auto() default true;
    }

     

  •  RPCmethod注解
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface RPCmethod {
        String remoteName();
    }

     

  • RPClock
    public interface RPClock {
        public void wakeUpLock();
    }

     

  • RPCMethodDefinition(对RPC服务的包装,RPC服务器最终是通过反射该方法得到结果返回给RPC客户端)
    import java.lang.reflect.Method;
    
    public class RPCMethodDefinition {
        //要通过反射调用的类
        private Class klass;
        //要通过反射执行方法的对象
        private Object object;
        //要执行的方法
        private Method method;
    
        public Class getKlass() {
            return klass;
        }
    
        public RPCMethodDefinition setKlass(Class klass) {
            this.klass = klass;
            return this;
        }
    
        public Object getObject() {
            return object;
        }
    
        public RPCMethodDefinition setObject(Object object) {
            this.object = object;
            return this;
        }
    
        public Method getMethod() {
            return method;
        }
    
        public RPCMethodDefinition setMethod(Method method) {
            this.method = method;
            return this;
        }
    }

     

  • RPCProxy(相当于RPC客户端,每一个RPCProxy都对应一个RPC服务器,从这里取得的代理对象就可以直接远程调用RPC服务器的方法)
    import com.google.gson.Gson;
    import com.mec.rpc.annotation.RPCmethod;
    import com.mec.rpc.exception.MethodNotHaveRPCAnnotation;
    import com.mec.rpc.exception.RPCNotFoundProxyObject;
    import com.mec.rpc.exception.RPCProxyReadOutTimeException;
    import com.mec.uitl.GsonUitl;
    import net.sf.cglib.proxy.Enhancer;
    import net.sf.cglib.proxy.MethodInterceptor;
    import net.sf.cglib.proxy.MethodProxy;
    
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.lang.reflect.Method;
    import java.net.Socket;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.Map;
    
    public class RPCProxy {
        //RPC服务器IP
        private String serverIp;
        //RPC服务器端口
        private int port;
        //可以将对象序列化(实际上是转为json字符串)的一种工具
        private static Gson gson;
        //设置超时时间
        private int waitingTime;
        //RPC服务器返回结果后通知的对象(这个是我自己实际工程中的应用,可以忽略)
        private Map wakeUpMap;
    
        static {
            //相当于gson = new GsonBuilder().create();
            gson = GsonUitl.gson;
        }
    
        public RPCProxy(String serverIp, int port, int waitingTime) {
            this.serverIp = serverIp;
            this.port = port;
            this.waitingTime = waitingTime;
            this.wakeUpMap = new HashMap();
        }
    
        public String getServerIp() {
            return serverIp;
        }
    
        public int getport() {
            return port;
        }
    
        public RPCProxy setwaitingTime(int waitingTime) {
            this.waitingTime = waitingTime;
            return this;
        }
    
        //通过Object取得代理(由于客户端要通过RPC调用的方法是已知的,因此不存在再运行过程中动态代理)
        public  T getProxy(Object object) {
            return getProxy(object.getClass());
        }
    
        //通过Class取得代理
        public  T getProxy(Class klass) {
            //这是CGlib代理的一般步骤
            Enhancer enhancer = new Enhancer();
            enhancer.setSuperclass(klass);
            enhancer.setCallback(new MethodInterceptor() {
                //这是通过代理对象执行方法时真正执行的代码
                public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
                    //因为本人工程的需要要避免对hashCode、equals和toSting方法的远程调用
                    if(method.getName().equals("hashCode") || method.getName().equals("equals") || method.getName().equals("toString")) {
                       return methodProxy.invokeSuper(o, objects);
                    }
    
                    //要求要被代理的方法必须带有RPCmethod注解,否则抛出异常
                    if(!method.isAnnotationPresent(RPCmethod.class)) {
                        throw new MethodNotHaveRPCAnnotation("方法[" + method + "]没有RPCmethod注解");
                    }
    
                    //取得被代理方法的上的注解
                    RPCmethod rpcMethod = method.getAnnotation(RPCmethod.class);
                    //取得被代理方法的远程调用方法的名字
                    String remoteName = rpcMethod.remoteName();
                    //将每个参数序列化并依次装入列表里
                    ArrayList parameter = new ArrayList();
                    int parameterCount = objects.length;
                    for(int i = 0; i < parameterCount; i++) {
                        parameter.add(gson.toJson(objects[i]));
                    }
    
                    //和RPC服务器建立通信信道
                    Socket socket = new Socket(serverIp, port);
                    DataInputStream dis = new DataInputStream(socket.getInputStream());
                    DataOutputStream dos = new DataOutputStream(socket.getOutputStream());
    
                    //初始化时间控制线程
                    Object lock = new Object();
                    ControlThread controlThread = new ControlThread(lock, dis, dos);
                    synchronized (lock) {
                        new Thread(controlThread).start();
                        lock.wait();
                    }
    
                    //首先发送远程调用方法的remoteName
                    dos.writeUTF(remoteName);
                    //如果参数大于0,则把参数列表序列化并发送
                    if(parameterCount > 0) {
                        //序列化为json字符串
                        String str = gson.toJson(parameter);
                        //将字符串转换为二进制发送,需先发送一个头部表明需要接收的长度
                        byte[]  bytes = str.getBytes();
                        dos.writeInt(bytes.length);
                        dos.write(bytes);
                    }
    
                    Object result = null;
                    try {
                        //等待RPC服务器返回结果
                        int length = dis.readInt();
                        byte[] bytes = new byte[length];
                        dis.readFully(bytes, 0 , length);
    
                         String reStr = new String(bytes);
                        if (reStr.equals("null")) {
                            return null;
                        }
    
                        result = gson.fromJson(reStr.trim(), method.getGenericReturnType());
                    } catch (Exception e) {
                        //发生异常有三种情况,
                        //一、无法和服务器建立连接
                        //二、通信过程中和服务器断开连接
                        //三、服务器处理请求超时
                        e.printStackTrace();
                        throw new RPCProxyReadOutTimeException("向服务器【" + serverIp + "】请求数据超时");
                    } finally {
                        //无论是否成功都要通知相应的RPClock
                        RPClock rpClock = wakeUpMap.get(o);
                        if(rpClock != null) {
                            rpClock.wakeUpLock();
                        }
    
                        //关闭输入信道
                        synchronized (dis) {
                            try {
                                dis.close();
                            } catch (IOException e) {
                            }
                        }
    
                        //关闭输出信道
                        synchronized (dos) {
                            try {
                                dos.close();
                            } catch (IOException e) {
                            }
                        }
    
                        //关闭套接字
                        try {
                            socket.close();
                        } catch (IOException e) {
                        }
                    }
                    return result;
                }
            });
    
            Object cglibProxy  = enhancer.create();
            //每一个代理对象对应一个RPClock
            wakeUpMap.put(cglibProxy, null);
            return (T) cglibProxy;
        }
    
        //通过添加RPClock的方法
        public void putRPClock(Object object, RPClock rpcLock) throws RPCNotFoundProxyObject {
            boolean found = false;
            for(Object o: wakeUpMap.keySet()) {
                if(o.equals(object)) {
                    found = true;
                }
            }
    
            if(!found) {
                throw new RPCNotFoundProxyObject("参数" + object + "未在" + this +"代理中找到");
            }
    
            wakeUpMap.put(object, rpcLock);
        }
    
        //远程调用时间控制线程
        private class ControlThread implements Runnable {
            private DataInputStream dis;
            private DataOutputStream dos;
            private Object lock;
    
            public ControlThread(Object lock, DataInputStream dis, DataOutputStream dos) {
                this.lock = lock;
                this.dis = dis;
                this.dos = dos;
            }
    
            public void run() {
                synchronized (lock) {
                    lock.notify();
                }
    
                try {
                    Thread.sleep(waitingTime);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
    
                //超出时间以后直接若socket未关闭则强制关闭
                try {
                    synchronized (dis) {
                        dis.close();
                    }
                    synchronized (dos) {
                        dos.close();
                    }
                } catch (IOException e) {
                }
            }
        }
    }
    

     

  • RPCServer(RPC服务器,拥有可以被远程掉用的所有方法并装载一个Map里面,侦听到方法调用的请求就开启一个线程去处理,有点类似于NIO)
    import com.google.gson.Gson;
    import com.mec.rpc.annotation.RPCclass;
    import com.mec.rpc.annotation.RPCmethod;
    import com.mec.rpc.exception.ClassNotHaveRPCAnnotation;
    import com.mec.rpc.exception.RPCMethodParametersLengthMismatching;
    import com.mec.rpc.exception.RPCServiceNotRegister;
    import com.mec.rpc.exception.RPCServiceObjectNotLoad;
    import com.mec.uitl.GsonUitl;
    import com.mec.uitl.PackageScanner;
    import com.mec.uitl.TypeUitl;
    
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.lang.reflect.InvocationTargetException;
    import java.lang.reflect.Method;
    import java.lang.reflect.Type;
    import java.net.ServerSocket;
    import java.net.Socket;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.Map;
    
    public class RPCServer implements Runnable {
        //用来装可以被远程调用的方法
        private Map serviceMap;
        //控制服务器侦听客户端线程的变量
        private volatile boolean goon;
        private ServerSocket serverSocket;
        //json序列化工具
        private static Gson gson;
        //用来得到传过来的实参
        private static Type type;
    
    
        static {
            gson = GsonUitl.gson; //        gson = new GsonBuilder().create();
            type = TypeUitl.type; //        type = new TypeToken>() {}.getType();
        }
    
        public RPCServer() {
            serviceMap = new HashMap();
            goon = false;
        }
    
        //启动RPC服务器
        public void start(int port) throws IOException {
            serverSocket = new ServerSocket(port);
    
            goon = true;
            new Thread(this).start();
        }
    
        //关闭服务器
        public void stop() {
            if(serverSocket == null || serverSocket.isClosed()) {
                return;
            }
            goon = false;
            try {
                serverSocket.close();
            } catch (IOException e) {
                e.printStackTrace();
            } finally {
                serverSocket = null;
            }
        }
    
    
        //服务器主线程,侦听到就启动一个线程去处理RPC请求
        public void run() {
            while (goon) {
                try {
                    Socket socket = serverSocket.accept();
                    new Thread(new DealRequest(socket)).start();
                } catch (IOException e) {
                    if(goon) {
                        e.printStackTrace();
                    } else {
                        break;
                    }
                }
            }
        }
    
    
        //处理远程掉用的请求并返回结果
        private class DealRequest implements Runnable {
            private Socket socket;
            private DataOutputStream dos;
            private DataInputStream dis;
    
            public DealRequest(Socket socket) {
                this.socket = socket;
                try {
                    dis = new DataInputStream(socket.getInputStream());
                    dos = new DataOutputStream(socket.getOutputStream());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
    
            //这里的读和写与RPCProxy写和读一一对应
            public void run() {
                try {
                    //读取远程调用的remoteName
                    String remoteName = dis.readUTF();
                    //从RPC服务集合里取出对应的RPCMethodDefinition
                    RPCMethodDefinition rdf = serviceMap.get(remoteName);
                    if(rdf == null) {
                        throw new RPCServiceNotRegister("服务" + remoteName + "未注册");
                    }
    
                    Object object = rdf.getObject();
    
                    if(object == null) {
                        throw new RPCServiceObjectNotLoad("服务" + remoteName + "对象未加载");
                    }
    
                    Method method = rdf.getMethod();
                    Type[] parameters = method.getGenericParameterTypes();
                    int parameterCount = parameters.length;
                    Object result = null;
    
                    if(parameterCount <= 0) {
                        //如果参数为0,直接反射执行该方法
                        result = method.invoke(object);
                    } else {
                        //如果参数不为0.先接收序列化的实参,再进行反序列化
                        int length = dis.readInt();
                        byte[] bytes = new byte[length];
                        dis.readFully(bytes, 0 , length);
                        String parameter = new String(bytes);
                        ArrayList parameterList = gson.fromJson(parameter, type);
                        if(parameterCount != parameterList.size()) {
                            throw new RPCMethodParametersLengthMismatching("服务" + remoteName + "参数个数不匹配" + parameterCount + " " + parameterList.size());
                        }
    
                        Object[] argus = new Object[parameterCount];
    
                        for (int i = 0; i < parameterCount; i++) {
                            argus[i] = gson.fromJson(parameterList.get(i), parameters[i]);
                        }
                        result = method.invoke(object, argus);
    
                    }
    
                    byte[] bytes= gson.toJson(result).getBytes();
                    dos.writeInt(bytes.length);
                    dos.write(bytes);
                    //确认数据全部发送完毕之后关闭相应资源
                    dos.flush();
            } catch (IOException e) {
                    e.printStackTrace();
                } catch (RPCServiceNotRegister rpcServiceNotRegister) {
                    rpcServiceNotRegister.printStackTrace();
                } catch (RPCServiceObjectNotLoad rpcServiceObjectNotLoad) {
                    rpcServiceObjectNotLoad.printStackTrace();
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                } catch (InvocationTargetException e) {
                    e.printStackTrace();
                } catch (RPCMethodParametersLengthMismatching rpcMethodParametersLengthMismatching) {
                    rpcMethodParametersLengthMismatching.printStackTrace();
                } finally {
                    try {
                        dis.close();
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                    try {
                        dos.close();
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                    try {
                        socket.close();
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
    
            }
        }
    
        //自动扫描指定包下的所有类,当检测到带有RPCclass注解的类时实行自动装入
        public void scanPackage(String path) {
            //自制的包扫描工具,可以遍历指定包下的所有类
            new PackageScanner() {
                public void dealClass(Class klass) {
                    try {
                        if (klass.isAnnotationPresent(RPCclass.class)) {
                            addService(klass);
                        }
                    } catch (ClassNotHaveRPCAnnotation classNotHaveRPCAnnotation) {
                        return;
                    }
                }
            }.packageScanner(path);
        }
    
        //提供给外部增加RPCmethod的方法,参数为object时,表明该object是其调用对应类的方法时要反射执行方法的对象
        public void addService(Object object) throws ClassNotHaveRPCAnnotation {
            Class klass = object.getClass();
    
            if(!klass.isAnnotationPresent(RPCclass.class)) {
                throw new ClassNotHaveRPCAnnotation("class[" + klass + "]没有RPCclass注解");
            }
    
            addService(klass, object);
        }
    
        //若只给一个Class对象,如果该对象是自动装载,则自动创造一个新的对象,若该对象是另外装入,则RPCMethodDefinition里的object暂为null
        public void addService(Class klass) throws ClassNotHaveRPCAnnotation {
            if(!klass.isAnnotationPresent(RPCclass.class)) {
                throw new ClassNotHaveRPCAnnotation("class[" + klass + "]没有RPCclass注解");
            }
    
            RPCclass rpCclass = (RPCclass) klass.getAnnotation(RPCclass.class);
            try {
                Object object = rpCclass.auto() ? klass.newInstance() : null;
                addService(klass, object);
            } catch (InstantiationException e) {
                e.printStackTrace();
            } catch (IllegalAccessException e) {
                e.printStackTrace();
            }
        }
    
        private void addService(Class klass, Object object) {
                Method[] methods = klass.getDeclaredMethods();
                for(Method method : methods) {
                    if(!method.isAnnotationPresent(RPCmethod.class)) {
                        continue;
                    }
    
                    RPCmethod rpcMethod = method.getAnnotation(RPCmethod.class);
                    String remoteName = rpcMethod.remoteName();
                    RPCMethodDefinition rdf = serviceMap.get(remoteName);
                    //若rdf不为null,则表明是延迟加载的对象,只需设置对应的object就可以
                    if(rdf != null) {
                        rdf.setObject(object);
                        continue;
                    }
                    serviceMap.put(remoteName, new RPCMethodDefinition()
                            .setKlass(klass)
                            .setMethod(method)
                            .setObject(object));
                }
        }
    
    }
    
     

     

  • PackageScanner(包扫描工具)
    import java.io.File;
    import java.io.IOException;
    import java.net.JarURLConnection;
    import java.net.URISyntaxException;
    import java.net.URL;
    import java.util.Enumeration;
    import java.util.jar.JarEntry;
    import java.util.jar.JarFile;
    
    public abstract class PackageScanner {
        private ClassLoader classLoader;
    
        public PackageScanner() {
        }
    
        //实参为一个类,则得到其所在包的路径
        public void packageScanner(Class clazz) {
            this.packageScanner(clazz.getPackage().getName());
        }
    
        //实参为包路径,例如com.mec
        public void packageScanner(String rootPackage) {
            String rootPath = rootPackage.replace(".", "/");
            this.classLoader = Thread.currentThread().getContextClassLoader();
    
            try {
                //通过这种可以得到包录路径下所有的类,包括jar包里
                Enumeration urls = this.classLoader.getResources(rootPath);
    
                while(urls.hasMoreElements()) {
                    URL url = (URL)urls.nextElement();
                    String jarProtocol = url.getProtocol();
                    //普通类的jarProtocol为file,jar包里的类为jar
                    if (jarProtocol.equals("file")) {
                        try {
                            File file = new File(url.toURI());
                            this.scanPackage(file.getAbsolutePath(), rootPackage);
                        } catch (URISyntaxException var7) {
                            var7.printStackTrace();
                        }
                    } else if (jarProtocol.equals("jar")) {
                        this.scanPackage(url);
                    }
                }
            } catch (IOException var8) {
                var8.printStackTrace();
            }
    
        }
    
        //这是抛给外部来处理遍历到到的类的方法
        public abstract void dealClass(Class var1);
    
        //处理不是jar包里的类,利用递归遍历的方法全部过一遍
        private void scanPackage(String path, String packageName) {
            File curFile = new File(path);
            if (curFile.exists()) {
                File[] files = curFile.listFiles();
                File[] var5 = files;
                int var6 = files.length;
    
                for(int var7 = 0; var7 < var6; ++var7) {
                    File file = var5[var7];
                    if (file.isDirectory()) {
                        this.scanPackage(file.getAbsolutePath(), packageName + "." + file.getName());
                    } else if (file.isFile() && file.getName().endsWith(".class")) {
                        String fileName = file.getName();
                        int dotInde = fileName.indexOf(".class");
                        fileName = fileName.substring(0, dotInde);
                        String className = packageName + "." + fileName;
    
                        try {
                            Class klass = Class.forName(className);
                            this.dealClass(klass);
                        } catch (ClassNotFoundException var13) {
                            var13.printStackTrace();
                        }
                    }
                }
    
            }
        }
    
        //扫描jar包里的类
        private void scanPackage(URL url) {
            try {
                JarURLConnection jarURLConnection = (JarURLConnection)url.openConnection();
                JarFile jarFile = jarURLConnection.getJarFile();
                Enumeration jarEntries = jarFile.entries();
    
                while(jarEntries.hasMoreElements()) {
                    JarEntry jarEntry = (JarEntry)jarEntries.nextElement();
                    if (!jarEntry.isDirectory() && jarEntry.getName().endsWith(".class")) {
                        String className = jarEntry.getName();
                        int dotIndex = className.indexOf(".class");
                        className = className.substring(0, dotIndex).replace("/", ".");
                        if (className.startsWith("com.mec")) {
                            try {
                                Class klass = Class.forName(className);
                                this.dealClass(klass);
                            } catch (ClassNotFoundException var9) {
                                var9.printStackTrace();
                            }
                        }
                    }
                }
            } catch (IOException var10) {
                var10.printStackTrace();
            }
    
        }
    }

     

  • Studen(测试用的类)
    import com.mec.rpc.annotation.RPCclass;
    import com.mec.rpc.annotation.RPCmethod;
    
    @RPCclass
    public class Student {
        private String name;
    
        @RPCmethod(remoteName = "getStudentName")
        public String getName() {
            return name == null ? "小明" : name;
        }
    
        @RPCmethod(remoteName = "setStudentName")
        public Student setName(String name) {
            this.name = name;
            return this;
        }
    }

     

  • ServerTest(服务器程序)
    import java.io.IOException;
    
    public class ServerTest {
        public static void main(String[] args) {
            //创建一个RPC服务器并注册一个学生为小绿
            RPCServer rpcServer = new RPCServer();
            Student student = new Student();
            student.setName("小绿");
            rpcServer.scanPackage("com.mec.Test");
            rpcServer.addService(student);
            try {
                rpcServer.start(54196);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

     

  • ClientTest(客户端程序)
    import com.mec.rpc.core.RPCProxy;
    
    public class ClientTest {
        public static void main(String[] args) {
            RPCProxy localProxy = new RPCProxy("localhost", 54196, 5000);
            Student student = localProxy.getProxy(Student.class);
            System.out.println(student.getName());
        }
    }

     

  • 结果
    Java实现RPC(服务对象使用注解并自动注入)_第2张图片

 如果失败的话输出的应该是小明,输出的是小绿则证明成功。

通过包扫描和注解的方式可以实现自动注入而无须手动添加RPC服务,包扫描技术非常方便的让我们处理类,通过反射可以获得有相应注解的类和方法是关键。

实际上在工程应用中哪个方法是通过RPC服务器调用的方法程序员是非常清楚的,屏蔽的只是是底层的通信细节,因此有时我们需要对方法进行异常处理,不让它影响整体程序的运行。
 

                            try {
                                studentName = student.getName();
                            }catch (Throwable e) {
                                if(e instanceof RPCProxyReadOutTimeException) {
                                    System.out.println("成功捕获超时异常");
                                }
                            }

 

你可能感兴趣的:(自制工具)