手写RPC框架

RPC框架核心组件

在RPC框架的最简模式下,主要有以下几个角色。这里暂且抛开心跳机制、负载均衡等复杂策略,先自己动手实现一个简单的RPC框架,后面再深入理解。
手写RPC框架_第1张图片

注册中心

RegisterServiceVo

package com.cover.rpc.remote.vo;

import java.io.Serializable;

/**
 * Address of one service provider as recorded by the registration center.
 *
 * <p>Immutable and {@link Serializable}. No {@code serialVersionUID} is declared, so the
 * default value derived from the class structure is used; any copy of this class on the
 * other end of a connection must therefore keep the same package, name and structure.
 *
 * <p>{@code equals}/{@code hashCode} are not overridden, so two instances with the same
 * host and port are still distinct elements of a hash-based set.
 */
public class RegisterServiceVo implements Serializable {

    // IP address of the service provider
    private final String host;
    
    // Port of the service provider
    private final int port;
    
    /**
     * @param host IP address of the service provider
     * @param port port of the service provider
     */
    public RegisterServiceVo(String host, int port) {
        this.host = host;
        this.port = port;
    }
    
    /** @return the provider's IP address as passed to the constructor */
    public String getHost() {
        return host;
    }
    
    /** @return the provider's port as passed to the constructor */
    public int getPort() {
        return port;
    }
}

RegisterCenter

package com.cover.rpc.rpc.reg.service;

import org.springframework.stereotype.Component;
import com.cover.rpc.remote.vo.RegisterServiceVo;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Service registration center: providers register their address here when they start,
 * and consumers query it to discover the providers of a service.
 *
 * <p>Wire protocol (JDK object streams over a plain blocking socket): the peer first writes a
 * boolean, {@code true} for a lookup and {@code false} for a registration, followed by the
 * service name (UTF). A registration additionally carries the provider host (UTF) and port
 * (int) and is answered with the boolean {@code true}. A lookup is answered with a serialized
 * {@code Set<RegisterServiceVo>}, which is empty when nothing is registered under that name.
 *
 * <p>One thread is started per accepted connection; each connection handles one request.
 */
@Component
public class RegisterCenter {
    // Port the registry listens on unless configured otherwise
    private static final int DEFAULT_PORT = 9999;

    // key: service name, value: addresses of every provider registered for that service.
    // Guarded by the class lock (see registerService / getService).
    private static final Map<String, Set<RegisterServiceVo>> serviceHolder
            = new HashMap<>();
    // Port the registry listens on
    private int port;

    /**
     * Registers a provider address for a service. Synchronized because several providers
     * may register at the same time.
     *
     * @param serviceName name under which the service is registered
     * @param host        provider IP address
     * @param port        provider port
     */
    private static synchronized void registerService(String serviceName, String host, int port) {
        // Fetch the addresses already known for this service, creating the set on first use
        Set<RegisterServiceVo> serviceVoSet = serviceHolder.get(serviceName);
        if (serviceVoSet == null) {
            serviceVoSet = new HashSet<>();
            serviceHolder.put(serviceName, serviceVoSet);
        }

        // RegisterServiceVo does not override equals/hashCode, so the set alone cannot detect
        // a provider that registers twice (e.g. after a restart); compare host and port here.
        boolean alreadyRegistered = false;
        for (RegisterServiceVo registered : serviceVoSet) {
            if (registered.getPort() == port && host.equals(registered.getHost())) {
                alreadyRegistered = true;
                break;
            }
        }
        if (!alreadyRegistered) {
            serviceVoSet.add(new RegisterServiceVo(host, port));
        }

        System.out.println("服务已注册[" + serviceName + "]," + "地址[ " + host + "], 端口[" + port + "]" );
    }


    /**
     * Looks up the providers of a service.
     *
     * @param serviceName name of the service
     * @return a snapshot of the registered addresses, never {@code null}. A copy is returned
     *         because the caller serializes the set outside the lock while other threads may
     *         still be registering providers.
     */
    private static synchronized Set<RegisterServiceVo> getService(String serviceName) {
        Set<RegisterServiceVo> snapshot = new HashSet<>();
        Set<RegisterServiceVo> registered = serviceHolder.get(serviceName);
        if (registered != null) {
            snapshot.addAll(registered);
        }
        return snapshot;
    }

    /**
     * Handles a single client connection. There are exactly two kinds of request:
     * 1. register a service provider
     * 2. look up the providers of a service
     */
    private static class ServerTask implements Runnable {

        private final Socket client;


        public ServerTask(Socket client) {
            this.client = client;
        }

        @Override
        public void run() {
            // The socket is listed first so it is closed last, whichever step fails.
            // The input stream must be opened before the output stream is used: the peer
            // writes its stream header first and only reads after flushing its request.
            try (
                    Socket connection = client;
                    ObjectInputStream inputStream = new ObjectInputStream(connection.getInputStream());
                    ObjectOutputStream outputStream = new ObjectOutputStream(connection.getOutputStream())
            ) {

                // First field tells whether this is a lookup (true) or a registration (false)
                boolean isGetService = inputStream.readBoolean();
                if (isGetService) {
                    // Lookup: send the current provider set back to the consumer
                    String serviceName = inputStream.readUTF();
                    Set<RegisterServiceVo> result = getService(serviceName);
                    outputStream.writeObject(result);
                    outputStream.flush();
                    System.out.println("将已注册的服务[" + serviceName + "提供给客户端");
                } else {
                    // Registration: read the provider's service name, IP and port
                    String serviceName = inputStream.readUTF();
                    String host = inputStream.readUTF();
                    int port = inputStream.readInt();
                    registerService(serviceName, host, port);
                    // Acknowledge so the provider knows the registration was stored
                    outputStream.writeBoolean(true);
                    outputStream.flush();

                }
            } catch (IOException e) {
                // A broken or malformed client connection only affects this request;
                // report it and let the handler thread finish normally.
                e.printStackTrace();
            }
        }
    }

    /**
     * Binds the listening socket and serves requests until the thread is terminated or
     * accepting fails. This method blocks; run it on a dedicated thread.
     *
     * @throws IOException if the port cannot be bound or accepting a connection fails
     */
    public void startService() throws IOException {
        // try-with-resources also releases the socket when bind() itself fails
        try (ServerSocket serverSocket = new ServerSocket()) {
            serverSocket.bind(new InetSocketAddress(port));
            System.out.println("服务注册中心 on :" + port + ": 运行");
            while (true) {
                new Thread(new ServerTask(serverSocket.accept())).start();
            }
        }
    }

    /**
     * Starts the registry on a background thread once the bean has been constructed, so the
     * blocking accept loop does not hold up container start-up.
     */
    @PostConstruct
    public void init() {
        this.port = DEFAULT_PORT;
        new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    startService();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }).start();
    }
}

文件结构
手写RPC框架_第2张图片

服务提供者

UserInfo

package com.cover.rpc.remote.vo;

import java.io.Serializable;

/**
 * Immutable, serializable carrier for a message recipient's name and phone number; it is
 * the argument type of {@code SendSms.sendMail}.
 *
 * <p>No {@code serialVersionUID} is declared, so the default value derived from the class
 * structure is used when instances travel over an object stream.
 */
public class UserInfo implements Serializable {

    // Recipient name
    private final String name;
    
    // Recipient phone number
    private final String phone;
    
    /**
     * @param name  recipient name
     * @param phone recipient phone number
     */
    public UserInfo(String name, String phone) {
        this.name = name;
        this.phone = phone;
    }

    /** @return the recipient name as passed to the constructor */
    public String getName() {
        return name;
    }

    /** @return the recipient phone number as passed to the constructor */
    public String getPhone() {
        return phone;
    }
}

SendSms

package com.cover.rpc.remote;


import com.cover.rpc.remote.vo.UserInfo;

/**
 * Service contract for sending a short message to a user. The fully qualified name of this
 * interface is used as the service name when the provider registers itself.
 */
public interface SendSms {
    
    /**
     * Sends a short message to the given user.
     *
     * @param userInfo the recipient's name and phone number
     * @return presumably {@code true} when the message was sent; the implementation shown
     *         alongside this interface always returns {@code true}
     */
    boolean sendMail(UserInfo userInfo);
}

RegisterServiceWithRegCenter

package com.cover.rpc.rpc.base;

import org.springframework.stereotype.Service;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Provider-side registration helper: announces a service to the remote registration center
 * and remembers the implementing class locally so incoming calls can be dispatched to it.
 */
@Service
public class RegisterServiceWithRegCenter {

    // Address of the registration center
    private static final String REGISTRY_HOST = "127.0.0.1";
    private static final int REGISTRY_PORT = 9999;

    // Services offered by this process: service name -> implementing class
    private static final Map<String, Class<?>> serviceCache = new ConcurrentHashMap<>();

    /**
     * Registers the service with the remote registration center and, once the registry has
     * answered, records the implementation in the local cache.
     *
     * <p>Failures are reported on standard error and not propagated; in that case the
     * service is not added to the local cache.
     *
     * @param serviceName name under which the service is published
     * @param host        IP address consumers should connect to
     * @param port        port consumers should connect to
     * @param impl        class implementing the service
     * @throws IOException declared for compatibility with existing callers
     */
    public void regRemote(String serviceName, String host, int port, Class<?> impl) throws IOException {
        // Closing the socket also closes both object streams, on success and on failure.
        // The original only released these resources inside the catch branch, leaking them
        // after every successful registration.
        try (Socket socket = new Socket()) {
            socket.connect(new InetSocketAddress(REGISTRY_HOST, REGISTRY_PORT));

            ObjectOutputStream output = new ObjectOutputStream(socket.getOutputStream());
            // false = "this is a registration", as opposed to a lookup
            output.writeBoolean(false);
            // Name of the offered service
            output.writeUTF(serviceName);
            // Provider IP
            output.writeUTF(host);
            // Provider port
            output.writeInt(port);

            output.flush();

            // Open the input stream only after the request has been flushed: its constructor
            // blocks until the registry has written its stream header.
            ObjectInputStream input = new ObjectInputStream(socket.getInputStream());
            if (input.readBoolean()) {
                System.out.println("服务[" + serviceName + "]注册成功!");
            }

            // Make the service available for local dispatch
            serviceCache.put(serviceName, impl);
        } catch (Exception e) {
            e.printStackTrace();
        }

    }


    /**
     * Returns the implementation registered locally for a service.
     *
     * @param serviceName name of the service
     * @return the implementing class, or {@code null} if no such service was registered
     */
    public Class<?> getLocalService(String serviceName) {
        return serviceCache.get(serviceName);
    }
}

RpcServerFrame

package com.cover.rpc.rpc.base;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;

/**
 * Server side of the RPC framework: accepts connections from consumers, reads one call
 * description per connection, invokes the matching local implementation via reflection and
 * writes the result back.
 */
@Service
public class RpcServerFrame {

    @Autowired
    private RegisterServiceWithRegCenter registerServiceWithRegCenter;


    /** Handles one remote call on one accepted connection. */
    private static class ServerTask implements Runnable {

        private final Socket socket;

        private final RegisterServiceWithRegCenter registerServiceWithRegCenter;

        public ServerTask(Socket client, RegisterServiceWithRegCenter registerServiceWithRegCenter) {
            this.socket = client;
            this.registerServiceWithRegCenter = registerServiceWithRegCenter;
        }

        @Override
        public void run() {
            // The socket is listed first so it is closed last, whichever step fails.
            try (
                    Socket connection = socket;
                    ObjectInputStream inputStream = new ObjectInputStream(connection.getInputStream());
                    ObjectOutputStream outputStream = new ObjectOutputStream(connection.getOutputStream())
            ) {

                // Name of the service interface the method belongs to
                String serviceName = inputStream.readUTF();
                // Name of the method to call
                String methodName = inputStream.readUTF();
                // NOTE(security): readObject() deserializes whatever the peer sends. This is
                // only acceptable on a trusted network; consider an ObjectInputFilter or a
                // data-only format before exposing this port.
                // Parameter types of the method
                Class<?>[] paramTypes = (Class<?>[]) inputStream.readObject();
                // Argument values
                Object[] args = (Object[]) inputStream.readObject();

                // Look up the implementing class registered for this service
                Class<?> serviceClass = registerServiceWithRegCenter.getLocalService(serviceName);
                if (serviceClass == null) {
                    throw new ClassNotFoundException(serviceName + " not found");
                }

                // Invoke the real service reflectively on a fresh instance.
                // getDeclaredConstructor().newInstance() replaces the deprecated
                // Class.newInstance(), which rethrows constructor exceptions unchecked.
                Method method = serviceClass.getMethod(methodName, paramTypes);
                Object target = serviceClass.getDeclaredConstructor().newInstance();
                Object result = method.invoke(target, args);

                // Send the result back to the caller
                outputStream.writeObject(result);
                outputStream.flush();

            } catch (Exception e) {
                // Nothing is written back on failure; the caller sees the connection close.
                e.printStackTrace();
            }
        }
    }

    /**
     * Binds the RPC port, registers the service with the registration center and then serves
     * calls until the thread is terminated or accepting fails. This method blocks.
     *
     * @param serviceName name under which the service is published
     * @param host        IP address announced to the registry
     * @param port        port to listen on and to announce to the registry
     * @param impl        class implementing the service
     * @throws IOException if the port cannot be bound or accepting a connection fails
     */
    public void startService(String serviceName, String host, int port, Class<?> impl) throws IOException {
        // try-with-resources releases the socket even if bind() or regRemote() throws
        try (ServerSocket serverSocket = new ServerSocket()) {
            serverSocket.bind(new InetSocketAddress(port));
            System.out.println("RPC server on :" + port + ":运行");
            registerServiceWithRegCenter.regRemote(serviceName, host, port, impl);
            while (true) {
                new Thread(new ServerTask(serverSocket.accept(), registerServiceWithRegCenter)).start();
            }
        }
    }
}

SendSmsImpl

package com.cover.rpc.rpc.sms;

import com.cover.rpc.remote.SendSms;
import com.cover.rpc.remote.vo.UserInfo;

/**
 * Demo implementation of {@link SendSms}: simulates the latency of a real send and logs the
 * recipient to standard output.
 */
public class SendSmsImpl implements SendSms {
    /**
     * Pretends to send a short message to the given user.
     *
     * @param userInfo the recipient's name and phone number
     * @return always {@code true}
     * @throws RuntimeException if the thread is interrupted while waiting; the interrupt
     *                          flag is restored before the exception is thrown
     */
    @Override
    public boolean sendMail(UserInfo userInfo) {
        try {
            // Simulated processing time
            Thread.sleep(50);
        } catch (InterruptedException e) {
            // Restore the flag so code further up the stack can still observe the interrupt
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        System.out.println("已发送短信息给 :" + userInfo.getName() + "到[" + userInfo.getPhone());
        return true;
    }
}

SmsRpcServer

package com.cover.rpc.rpc.sms;

import com.cover.rpc.remote.SendSms;
import com.cover.rpc.rpc.base.RpcServerFrame;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.util.Random;

/**
 * Publishes {@link SendSmsImpl} as the {@link SendSms} RPC service when the application
 * context starts.
 */
@Service
public class SmsRpcServer {

    @Autowired
    private RpcServerFrame rpcServerFrame;

    /**
     * Picks a port in the range 8778-8877 and starts the RPC server for the SMS service.
     *
     * <p>{@code RpcServerFrame.startService} loops forever accepting connections, so it is run
     * on its own thread; calling it directly here would block bean initialisation and the
     * application context would never finish starting.
     *
     * @throws IOException declared for compatibility; start-up failures are reported by the
     *                     server thread on standard error
     */
    @PostConstruct
    public void server() throws IOException {
        Random r = new Random();
        final int port = 8778 + r.nextInt(100);
        new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    rpcServerFrame.startService(SendSms.class.getName(), "127.0.0.1", port, SendSmsImpl.class);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }).start();
    }
}

文件结构
手写RPC框架_第3张图片

服务消费者

BeanConfig

package com.cover.rpc.client.config;

import com.cover.rpc.client.rpc.RpcClientFrame;
import com.cover.rpc.remote.SendSms;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.io.IOException;

/**
 * Spring configuration on the consumer side: exposes remote services as ordinary beans.
 */
@Configuration
public class BeanConfig {

    @Autowired
    private RpcClientFrame rpcClientFrame;

    /**
     * Creates the bean through which the remote {@link SendSms} service is called.
     *
     * @return a proxy implementing {@link SendSms}
     * @throws IOException            propagated from the proxy factory
     * @throws ClassNotFoundException propagated from the proxy factory
     */
    @Bean
    public SendSms getSmsService() throws IOException, ClassNotFoundException {
        final Class<SendSms> contract = SendSms.class;
        final SendSms remoteProxy = rpcClientFrame.getRemoteProxyObject(contract);
        return remoteProxy;
    }

}

RpcClientFrame

package com.cover.rpc.client.rpc;

import com.cover.rpc.remote.vo.RegisterServiceVo;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * Client side of the RPC framework: asks the registration center for the providers of a
 * service and hands out a dynamic proxy that forwards every interface call to one of them.
 */
@Service
public class RpcClientFrame {

    // Address of the registration center
    private static final String REGISTRY_HOST = "127.0.0.1";
    private static final int REGISTRY_PORT = 9999;

    /**
     * Creates a proxy for a remote service.
     *
     * <p>The provider address is chosen once, when the proxy is created; every later call on
     * the proxy goes to that same address.
     *
     * @param serviceInterface interface of the service to call; its fully qualified name is
     *                         used as the service name
     * @param <T>              type the proxy is assigned to, expected to be the interface type
     * @return a proxy implementing {@code serviceInterface}
     * @throws IOException            declared for compatibility with existing callers
     * @throws ClassNotFoundException declared for compatibility with existing callers
     * @throws IllegalStateException  if no provider is registered for the service
     */
    @SuppressWarnings("unchecked") // the proxy implements serviceInterface by construction
    public static <T> T getRemoteProxyObject(final Class<?> serviceInterface) throws IOException,
            ClassNotFoundException {
        // Pick the network address of one provider of the service
        InetSocketAddress addr = getService(serviceInterface.getName());
        // The proxy performs the actual call over the network
        return (T) Proxy.newProxyInstance(serviceInterface.getClassLoader(), new Class<?>[]{serviceInterface},
                new DynProxy(serviceInterface, addr));
    }


    /** Invocation handler that ships each interface call to the remote provider. */
    private static class DynProxy implements InvocationHandler {

        private final Class<?> serviceInterface;

        private final InetSocketAddress addr;

        public DynProxy(Class<?> serviceInterface, InetSocketAddress addr) {
            this.serviceInterface = serviceInterface;
            this.addr = addr;
        }

        /**
         * Sends the call to the provider and returns its result.
         *
         * @throws RuntimeException if the provider cannot be reached or the reply cannot be
         *                          read. The original returned {@code null} here, which
         *                          surfaced as a NullPointerException for methods with a
         *                          primitive return type.
         */
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // equals/hashCode/toString are also routed to the handler; answer them locally
            // instead of shipping them (and possibly the proxy itself) over the network.
            if (method.getDeclaringClass() == Object.class) {
                return invokeObjectMethod(proxy, method, args);
            }

            // Closing the socket also closes both object streams
            try (Socket socket = new Socket()) {
                socket.connect(addr);
                ObjectOutputStream outputStream = new ObjectOutputStream(socket.getOutputStream());

                // Name of the service interface
                outputStream.writeUTF(serviceInterface.getName());
                // Method name
                outputStream.writeUTF(method.getName());
                // Parameter types
                outputStream.writeObject(method.getParameterTypes());
                // Argument values
                outputStream.writeObject(args);

                outputStream.flush();

                // Open the input stream only after the request has been flushed: its
                // constructor blocks until the server has written its stream header.
                ObjectInputStream inputStream = new ObjectInputStream(socket.getInputStream());
                Object result = inputStream.readObject();
                // Report success only once the reply has actually been read
                System.out.println(serviceInterface + " remote exec susccess!");
                return result;
            } catch (IOException | ClassNotFoundException e) {
                throw new RuntimeException("Remote call " + serviceInterface.getName() + "#"
                        + method.getName() + " to " + addr + " failed", e);
            }
        }

        // Local implementation of the java.lang.Object methods for the proxy instance
        private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
            String name = method.getName();
            if ("equals".equals(name)) {
                return proxy == args[0];
            }
            if ("hashCode".equals(name)) {
                return System.identityHashCode(proxy);
            }
            // Only toString remains
            return "RpcProxy[" + serviceInterface.getName() + " -> " + addr + "]";
        }

    }

    //------------------------------ provider discovery
    private static final Random r = new Random();

    /**
     * Chooses one provider of the service at random.
     *
     * @throws IllegalStateException if the registry returned no provider for the service
     */
    private static InetSocketAddress getService(String serviceName) throws IOException, ClassNotFoundException {
        // Addresses of all providers currently known to the registry
        List<InetSocketAddress> serviceList = getServiceList(serviceName);
        System.out.println("serviceList =" + serviceList.toString());
        if (serviceList.isEmpty()) {
            // Fail with a clear message instead of the IllegalArgumentException that
            // Random.nextInt(0) would raise
            throw new IllegalStateException("No provider available for service [" + serviceName + "]");
        }
        InetSocketAddress addr = serviceList.get(r.nextInt(serviceList.size()));
        System.out.println("本次选择了服务器: " + addr);
        return addr;
    }

    /**
     * Asks the registration center for the providers of a service.
     *
     * @return the provider addresses; empty if none are registered or the registry could not
     *         be queried (the failure is reported on standard error)
     */
    private static List<InetSocketAddress> getServiceList(String serviceName) throws IOException,
            ClassNotFoundException {
        List<InetSocketAddress> services = new ArrayList<>();

        // Closing the socket also closes both object streams
        try (Socket socket = new Socket()) {
            socket.connect(new InetSocketAddress(REGISTRY_HOST, REGISTRY_PORT));

            ObjectOutputStream output = new ObjectOutputStream(socket.getOutputStream());
            // true = "this is a lookup", as opposed to a registration
            output.writeBoolean(true);
            // Service name to look up
            output.writeUTF(serviceName);
            output.flush();

            ObjectInputStream input = new ObjectInputStream(socket.getInputStream());
            Object object = input.readObject();
            // String concatenation tolerates a null reply, unlike object.toString()
            System.out.println("从注册中心读取到的数据是" + object);
            @SuppressWarnings("unchecked") // the registry protocol replies with Set<RegisterServiceVo>
            Set<RegisterServiceVo> result = (Set<RegisterServiceVo>) object;
            if (result != null) {
                for (RegisterServiceVo serviceVo : result) {
                    services.add(new InetSocketAddress(serviceVo.getHost(), serviceVo.getPort()));
                }
            }

            System.out.println("获得服务[" + serviceName + "] 提供者的地址列表[" + services + "],准备调用");
        } catch (Exception e) {
            e.printStackTrace();
        }

        return services;
    }
}

NormalBusi

package com.cover.rpc.client.service;

import org.springframework.stereotype.Service;

/**
 * Placeholder for ordinary local (non-RPC) business logic.
 *
 * @author xieh
 * @date 2024/02/03 17:46
 */
@Service
public class NormalBusi {
    /** Writes a fixed marker line to standard output. */
    public void business() {
        final String marker = "其他的业务操作。。。。";
        System.out.println(marker);
    }
}

RegisterServiceVo

package com.cover.rpc.remote.vo;

import java.io.Serializable;

/**
 * Address of one service provider as recorded by the registration center.
 *
 * <p>Immutable and {@link Serializable}. No {@code serialVersionUID} is declared, so the
 * default value derived from the class structure is used; any copy of this class on the
 * other end of a connection must therefore keep the same package, name and structure.
 */
public class RegisterServiceVo implements Serializable {

    // IP address of the service provider
    private final String host; 
    
    // Port of the service provider
    private final int port;
    
    /**
     * @param host IP address of the service provider
     * @param port port of the service provider
     */
    public RegisterServiceVo (String host, int port) {
        this.host = host;
        this.port = port;
    }

    /** @return the provider's IP address as passed to the constructor */
    public String getHost() {
        return host;
    }

    /** @return the provider's port as passed to the constructor */
    public int getPort() {
        return port;
    }
}

UserInfo

package com.cover.rpc.remote.vo;

import java.io.Serializable;

/**
 * Immutable, serializable carrier for a message recipient's name and phone number; it is
 * the argument type of {@code SendSms.sendMail}.
 *
 * <p>No {@code serialVersionUID} is declared, so the default value derived from the class
 * structure is used when instances travel over an object stream.
 *
 * @author xieh
 * @date 2024/02/03 17:48
 */
public class UserInfo implements Serializable {

    // Recipient name
    private final String name;

    // Recipient phone number
    private final String phone;

    /**
     * @param name  recipient name
     * @param phone recipient phone number
     */
    public UserInfo(String name, String phone) {
        this.name = name;
        this.phone = phone;
    }

    /** @return the recipient name as passed to the constructor */
    public String getName() {
        return name;
    }

    /** @return the recipient phone number as passed to the constructor */
    public String getPhone() {
        return phone;
    }
}

SendSms

package com.cover.rpc.remote;

import com.cover.rpc.remote.vo.UserInfo;

/**
 * Service contract for sending a short message to a user. The fully qualified name of this
 * interface is used as the service name when looking up providers.
 *
 * @author xieh
 * @date 2024/02/03 17:47
 */
public interface SendSms {
    /**
     * Sends a short message to the given user.
     *
     * @param userInfo the recipient's name and phone number
     * @return presumably {@code true} when the message was sent; confirm against the
     *         provider's implementation
     */
    boolean sendMail(UserInfo userInfo);
}

RpcClientApplicationTests

package com.example.rpcclient;

import com.cover.rpc.client.service.NormalBusi;
import com.cover.rpc.remote.SendSms;
import com.cover.rpc.remote.vo.UserInfo;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

@SpringBootTest
class RpcClientApplicationTests {

    @Test
    void contextLoads() {
    }



    @Autowired
    private NormalBusi normalBusi;

    @Autowired
    private SendSms sendSms;

//    @Test
//    void contextLoads() {
//
//    }

    @Test
    public void rpcTest() {
        long start = System.currentTimeMillis();
        normalBusi.business();

        // 发送邮件
        UserInfo userInfo = new UserInfo("Cover", "181");
        System.out.println("Send mail" + sendSms.sendMail(userInfo));
        System.out.println("共耗时:" + (System.currentTimeMillis() - start));
    }
}

文件结构
手写RPC框架_第4张图片

结果展示

服务注册中心
手写RPC框架_第5张图片

服务消费者
手写RPC框架_第6张图片
服务提供者
手写RPC框架_第7张图片

总结分析

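1.本项目的序列化方式选用的是JDK自带的序列化。需要注意:客户端和服务端中同一个实体类的serialVersionUID必须一致——要么两边都不显式声明(此时由JVM根据类结构自动计算,因此两边的类定义必须完全相同),要么两边声明同一个值;如果两边的值不同,反序列化时会被视为不兼容的类而抛出InvalidClassException,导致通信失败。另外还要注意自己的目录结构:如果客户端和服务端中实体类的全限定名(包名+类名)不一致,同样无法反序列化。在实际业务中,往往会把这些实体类和接口抽成一个公共模块供双方依赖,这里为了简洁,直接在各个工程中各放了一份。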
2.网络通信模型采用的也是JDK自带的Socket模式,它是阻塞式的,无法支撑高并发的网络连接。

如果需要这个项目源码的,可以在下方评论

你可能感兴趣的:(网络IO,Netty,rpc,网络协议,网络,java)