A Beginner's Take on Distributed Database Design 3: External RPC Module Design

RPC design for client access

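The simplest RPC design takes only a small amount of code; its main performance costs are network communication and serialization/deserialization. Here we design a fairly general RPC framework that can perform arbitrary remote calls in Java.
We use Netty for network communication and, for now, support three Java serialization options: Hessian, Kryo, and the JDK's built-in serialization.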

Message data

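Given a request's interface name (stored as className), method name, parameter types, and argument values, the call can be handled with reflection on the server and dynamic proxies on the client: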

import java.io.Serializable;
import java.util.Arrays;

public class MessageRequest implements Serializable {

    private static final long serialVersionUID = 779639215038924077L;
    private String messageId;          // unique request ID
    private String className;          // interface name
    private String methodName;         // method name
    private Class<?>[] typeParameters; // parameter types
    private Object[] parametersVal;    // argument values

    public String getMessageId() {
        return messageId;
    }

    public void setMessageId(String messageId) {
        this.messageId = messageId;
    }

    public String getClassName() {
        return className;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class[] getTypeParameters() {
        return typeParameters;
    }

    public void setTypeParameters(Class[] typeParameters) {
        this.typeParameters = typeParameters;
    }

    public Object[] getParameters() {
        return parametersVal;
    }

    public void setParameters(Object[] parametersVal) {
        this.parametersVal = parametersVal;
    }

    @Override
    public String toString() {
        return "MessageRequest{" +
                "messageId='" + messageId + '\'' +
                ", className='" + className + '\'' +
                ", methodName='" + methodName + '\'' +
                ", typeParameters=" + Arrays.toString(typeParameters) +
                ", parametersVal=" + Arrays.toString(parametersVal) +
                '}';
    }
}
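
import java.io.Serializable;

public class MessageResponse implements Serializable {
    private static final long serialVersionUID = -4628239730293658445L;
    private String messageId;   // ID of the request this response answers
    private String error;       // error message, null on success
    private Object resultDesc;  // return value of the invoked method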

    public String getMessageId() {
        return messageId;
    }

    public void setMessageId(String messageId) {
        this.messageId = messageId;
    }

    public String getError() {
        return error;
    }

    public void setError(String error) {
        this.error = error;
    }

    public Object getResult() {
        return resultDesc;
    }

    public void setResult(Object resultDesc) {
        this.resultDesc = resultDesc;
    }

    @Override
    public String toString() {
        return "MessageResponse{" +
                "messageId='" + messageId + '\'' +
                ", error='" + error + '\'' +
                ", resultDesc=" + resultDesc +
                '}';
    }
}
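
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// Lets the calling thread block until the matching MessageResponse arrives
public class MessageCallBack {
    private MessageResponse response;
    private final Lock lock = new ReentrantLock();
    private final Condition finish = lock.newCondition();

    // Waits up to 10 seconds for the response; returns null on timeout
    public Object start() throws InterruptedException {
        lock.lock();
        try {
            long remainingNanos = TimeUnit.SECONDS.toNanos(10);
            // The loop skips waiting if over() already ran and guards against spurious wakeups
            while (response == null && remainingNanos > 0) {
                remainingNanos = finish.awaitNanos(remainingNanos);
            }
            return response != null ? response.getResult() : null;
        } finally {
            lock.unlock();
        }
    }

    // Called by the I/O thread when the response has been received
    public void over(MessageResponse response) {
        lock.lock();
        try {
            this.response = response;
            finish.signal();
        } finally {
            lock.unlock();
        }
    }
}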
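To make the proxy half of the reflection-and-proxy idea concrete, here is a minimal client-side sketch using the JDK's `java.lang.reflect.Proxy`. It is an illustration, not the framework's own `MessageSendExecutor`: `CapturedRequest`, `RpcProxySketch` and `Greeter` are hypothetical names, and the `transport` function stands in for sending the request over Netty and waiting on a `MessageCallBack`.

```java
import java.lang.reflect.Proxy;
import java.util.UUID;
import java.util.function.Function;

// Hypothetical stand-in for MessageRequest: the fields a proxy can capture from a call
final class CapturedRequest {
    final String messageId;
    final String className;
    final String methodName;
    final Class<?>[] typeParameters;
    final Object[] parameters;

    CapturedRequest(String messageId, String className, String methodName,
                    Class<?>[] typeParameters, Object[] parameters) {
        this.messageId = messageId;
        this.className = className;
        this.methodName = methodName;
        this.typeParameters = typeParameters;
        this.parameters = parameters;
    }
}

// Hypothetical service interface shared by client and server
interface Greeter {
    String greet(String name);
}

final class RpcProxySketch {
    // Returns an implementation of `iface` whose every call is packed into a CapturedRequest
    // and handed to `transport` (a stand-in for the Netty send plus the wait for the response)
    @SuppressWarnings("unchecked")
    static <T> T create(Class<T> iface, Function<CapturedRequest, Object> transport) {
        return (T) Proxy.newProxyInstance(
                iface.getClassLoader(),
                new Class<?>[] {iface},
                (proxy, method, args) -> transport.apply(new CapturedRequest(
                        UUID.randomUUID().toString(),   // correlates the response with this call
                        iface.getName(),                // interface name: the server's lookup key
                        method.getName(),
                        method.getParameterTypes(),     // lets the server pick the right overload
                        args == null ? new Object[0] : args)));
    }
}
```

Every call on such a proxy, including `toString()` or `hashCode()`, goes through the handler, so a real client would normally filter out methods declared by `Object`.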

Serialization and deserialization

To encode and decode messages, we define an interface:

public interface MessageCodecUtil {
    int MESSAGE_LENGTH = 4;

    void encode(final ByteBuf out, final Object message) throws IOException;

    Object decode(byte[] body) throws IOException;
}
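The constant `MESSAGE_LENGTH = 4` suggests that each message is framed as a 4-byte length followed by the serialized body; that reading is an assumption, since the codec implementations are not shown. The sketch below illustrates such length-prefix framing with `java.nio.ByteBuffer` in place of Netty's `ByteBuf`; `LengthPrefixFraming` is a hypothetical name. In a Netty pipeline, Netty's `LengthFieldBasedFrameDecoder` can take over the decoding half of this job.

```java
import java.nio.ByteBuffer;

// Hypothetical length-prefix framing: [4-byte big-endian length][body bytes]
final class LengthPrefixFraming {
    static final int MESSAGE_LENGTH = 4;

    // Prepends the body length to the serialized body
    static byte[] encode(byte[] body) {
        ByteBuffer buf = ByteBuffer.allocate(MESSAGE_LENGTH + body.length);
        buf.putInt(body.length);
        buf.put(body);
        return buf.array();
    }

    // Returns the next complete body from `in`, or null (leaving `in` where it was)
    // if the bytes received so far do not yet hold a full frame
    static byte[] decode(ByteBuffer in) {
        if (in.remaining() < MESSAGE_LENGTH) {
            return null;                 // not even a full header yet
        }
        in.mark();
        int length = in.getInt();
        if (length < 0) {
            throw new IllegalArgumentException("negative frame length: " + length);
        }
        if (in.remaining() < length) {
            in.reset();                  // rewind to the header and wait for more bytes
            return null;
        }
        byte[] body = new byte[length];
        in.get(body);
        return body;
    }
}
```

The length header is what lets the receiver split a TCP byte stream, which has no message boundaries of its own, back into individual requests and responses.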

Both encode and decode depend on serialization and deserialization, so we define another interface:

public interface RpcSerialize {
    void serialize(OutputStream output, Object object) throws IOException;

    Object deserialize(InputStream input) throws IOException;
}
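As an illustration of how one strategy might implement `RpcSerialize`, here is a sketch for the JDK-serialization option only (Kryo and Hessian would need their own libraries). The interface is repeated so the block compiles on its own; `JdkSerialize` and `SerializeDemo` are hypothetical names, not necessarily those used by the project.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;

// Same shape as RpcSerialize in the text, repeated so this sketch is self-contained
interface RpcSerialize {
    void serialize(OutputStream output, Object object) throws IOException;

    Object deserialize(InputStream input) throws IOException;
}

// Hypothetical JDK-serialization strategy
final class JdkSerialize implements RpcSerialize {
    @Override
    public void serialize(OutputStream output, Object object) throws IOException {
        ObjectOutputStream oos = new ObjectOutputStream(output);
        oos.writeObject(object);
        oos.flush();   // flush but do not close: the caller owns `output`
    }

    @Override
    public Object deserialize(InputStream input) throws IOException {
        ObjectInputStream ois = new ObjectInputStream(input);
        try {
            return ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new IOException("class in stream not found on this side", e);
        }
    }
}

final class SerializeDemo {
    // Round-trips a value through bytes, the way encode and decode would on the wire
    static Object roundTrip(RpcSerialize serializer, Object value) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            serializer.serialize(bytes, value);
            return serializer.deserialize(new ByteArrayInputStream(bytes.toByteArray()));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

JDK serialization needs no extra dependency but requires every transferred type to implement `Serializable`, which is why `MessageRequest` and `MessageResponse` do.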

Netty's childHandler needs a series of handlers added to its pipeline, so we define one more interface:

public interface SerializeFrame {
    void select(ChannelPipeline pipeline, Map handlerMap);
}

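Because several serialization methods are available, we use the strategy pattern, with each strategy corresponding to a different serialization and encoding scheme: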

public class SerializeContext {
    private SerializeFrame serializeFrame;

    public SerializeContext(SerializeFrame serializeFrame) {
        this.serializeFrame = serializeFrame;
    }

    public void setSerializeFrame(SerializeFrame serializeFrame) {
        this.serializeFrame = serializeFrame;
    }

    public void select(ChannelPipeline pipeline, Map handlerMap) {
        serializeFrame.select(pipeline,handlerMap);
    }
}
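The strategy pattern behind `SerializeFrame` and `SerializeContext` can be shown without Netty. In this hypothetical sketch, `EncodeStrategy` plays the role of `SerializeFrame`, `EncodeContext` plays `SerializeContext`, and two toy strategies that encode text as UTF-8 or UTF-16 stand in for the real JDK, Kryo and Hessian codecs.

```java
import java.nio.charset.StandardCharsets;

// Strategy interface: one implementation per wire format
interface EncodeStrategy {
    byte[] encode(String message);

    String name();
}

// Toy strategies standing in for real serialization codecs
final class Utf8Strategy implements EncodeStrategy {
    public byte[] encode(String message) { return message.getBytes(StandardCharsets.UTF_8); }

    public String name() { return "utf-8"; }
}

final class Utf16Strategy implements EncodeStrategy {
    public byte[] encode(String message) { return message.getBytes(StandardCharsets.UTF_16BE); }

    public String name() { return "utf-16be"; }
}

// Context: callers hold only this object and never touch a concrete strategy
final class EncodeContext {
    private EncodeStrategy strategy;

    EncodeContext(EncodeStrategy strategy) { this.strategy = strategy; }

    void setStrategy(EncodeStrategy strategy) { this.strategy = strategy; }

    byte[] encode(String message) { return strategy.encode(message); }

    @Override
    public String toString() { return "EncodeContext[" + strategy.name() + "]"; }
}
```

Because callers depend only on the context, switching protocols is a single `setStrategy` call and adding a new format means adding one class.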

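I won't paste the concrete implementations here. The one point worth noting is that object pooling is used when constructing serializer objects: Kryo has its own KryoPool, while Hessian needs Apache commons-pool2 to pool its objects.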
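Pooling helps because serializer objects such as Kryo instances are not thread-safe, so each thread borrows one instead of sharing or repeatedly creating it. The sketch below is a deliberately minimal bounded pool built on `ArrayBlockingQueue`; it is not the `KryoPool` or commons-pool2 API, and `SimplePool` is a hypothetical name.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.function.Supplier;

// Minimal bounded object pool; a stand-in for KryoPool / commons-pool2, not their API
final class SimplePool<T> {
    private final BlockingQueue<T> idle;
    private final Supplier<T> factory;

    SimplePool(int capacity, Supplier<T> factory) {
        this.idle = new ArrayBlockingQueue<>(capacity);
        this.factory = factory;
    }

    // Reuses an idle instance if one is available, otherwise creates a new one
    T borrow() {
        T obj = idle.poll();
        return obj != null ? obj : factory.get();
    }

    // Returns an instance for reuse; silently drops it if the pool is already full
    void release(T obj) {
        idle.offer(obj);
    }

    int idleCount() {
        return idle.size();
    }
}
```

A pooled serializer carries internal state, so it should be reset before it is released. Production pools such as commons-pool2 also add validation, eviction and a cap on total instances, which this sketch omits.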
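In addition, the server side needs an annotation that marks each service implementation class and names the interface it implements: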

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface RpcServer {
    Class<?> value();   // the service interface implemented by the annotated class
}

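To process this annotation, we mimic Spring-style dependency injection: when the service starts, every annotated service class is instantiated up front.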

import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.net.URL;
import java.net.URLDecoder;
import java.util.Enumeration;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class RpcAnnotationFactory {
    // Scans packageName on the file-system classpath and instantiates every class carrying
    // the given annotation; the returned map is keyed by implementation class name
    public static Map<String, Object> getBeansWithAnnotation(Class<? extends Annotation> annotation, String packageName) {
        Map<String, Object> beans = new ConcurrentHashMap<>();
        String packageDirName = packageName.replace('.', '/');
        try {
            Enumeration<URL> dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            while (dirs.hasMoreElements()) {
                URL url = dirs.nextElement();
                // Only exploded class directories are handled; classes inside jars are skipped
                if ("file".equals(url.getProtocol())) {
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    findAndAddClassesInPackageByFile(annotation, packageName, filePath, beans);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return beans;
    }

    private static void findAndAddClassesInPackageByFile(Class<? extends Annotation> annotation, String packageName,
                                                         String packagePath, Map<String, Object> beans) {
        File dir = new File(packagePath);
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }
        File[] dirFiles = dir.listFiles(pathname -> pathname.isDirectory() || pathname.getName().endsWith(".class"));
        if (dirFiles == null) {
            return;   // I/O error while listing the directory
        }
        for (File file : dirFiles) {
            if (file.isDirectory()) {
                // Recurse into sub-packages
                findAndAddClassesInPackageByFile(annotation, packageName + "." + file.getName(),
                        file.getAbsolutePath(), beans);
            } else {
                // Strip the ".class" suffix (6 characters)
                String className = file.getName().substring(0, file.getName().length() - 6);
                try {
                    Class<?> clazz = Thread.currentThread().getContextClassLoader()
                            .loadClass(packageName + '.' + className);
                    if (clazz.getAnnotation(annotation) != null) {
                        // Requires a public no-arg constructor
                        beans.put(clazz.getName(), clazz.getDeclaredConstructor().newInstance());
                    }
                } catch (ReflectiveOperationException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}
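Stripped of the classpath scanning, what `getBeansWithAnnotation` plus `init()` produce is a map from interface name to service instance. The sketch below builds that map from an explicit list of candidate classes, which is easier to test; `EchoService`, `EchoServiceImpl`, `NotAService` and `ServiceRegistry` are hypothetical names, and the annotation is redeclared so the block compiles on its own.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Same shape as the RpcServer annotation in the text
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface RpcServer {
    Class<?> value();
}

interface EchoService {               // hypothetical service interface
    String echo(String s);
}

@RpcServer(EchoService.class)
class EchoServiceImpl implements EchoService {
    public String echo(String s) { return s; }
}

class NotAService { }                 // no annotation, so the registry ignores it

final class ServiceRegistry {
    // Builds the interface-name -> instance map that init() produces
    static Map<String, Object> build(List<Class<?>> candidates) {
        Map<String, Object> handlerMap = new HashMap<>();
        for (Class<?> clazz : candidates) {
            RpcServer anno = clazz.getAnnotation(RpcServer.class);
            if (anno == null) {
                continue;             // not an RPC service
            }
            Object bean;
            try {
                bean = clazz.getDeclaredConstructor().newInstance();   // needs a no-arg constructor
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("cannot instantiate " + clazz.getName(), e);
            }
            // Catch a mismatched annotation at startup rather than on the first call
            if (!anno.value().isInstance(bean)) {
                throw new IllegalStateException(clazz.getName() + " does not implement " + anno.value().getName());
            }
            handlerMap.put(anno.value().getName(), bean);
        }
        return handlerMap;
    }
}
```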

Dependency injection happens when MessageRecvExecutor is initialized. Here we use Google Guava's ListeningExecutorService for asynchronous execution:

public class MessageRecvExecutor {
    private static Logger LOG = Logger.getLogger(MessageRecvExecutor.class);
    private String serverAddress;

    private SerializeFrame serializeFrame;

    private final static String DELIMITER = ":";

    private Map handlerMap = new ConcurrentHashMap<>();

    // volatile is required for the double-checked locking in submit() to be safe
    private static volatile ListeningExecutorService threadPoolExecutor;

    public MessageRecvExecutor(String serverAddress, SerializeFrame serializeFrame) {
        this.serverAddress = serverAddress;
        this.serializeFrame = serializeFrame;
    }

    public static void submit(Callable<Boolean> task, final ChannelHandlerContext context, final MessageRequest request, final MessageResponse response) {
        // Lazily create the shared business thread pool
        if (threadPoolExecutor == null) {
            synchronized (MessageRecvExecutor.class) {
                if (threadPoolExecutor == null) {
                    threadPoolExecutor = MoreExecutors.listeningDecorator((ThreadPoolExecutor) RpcThreadPool.getExecutor(16, -1));
                }
            }
        }

        // Run the reflective call off the Netty I/O thread, then write the response back
        ListenableFuture<Boolean> listenableFuture = threadPoolExecutor.submit(task);
        Futures.addCallback(listenableFuture, new FutureCallback<Boolean>() {
            @Override
            public void onSuccess(@Nullable Boolean aBoolean) {
                context.writeAndFlush(response).addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        LOG.info("RPC server sent response for message-id: " + request.getMessageId());
                    }
                });
            }

            @Override
            public void onFailure(Throwable throwable) {
                throwable.printStackTrace();
            }
        }, threadPoolExecutor);
    }

    public void init(String packageName){
        Map rpcServiceObject = RpcAnnotationFactory.getBeansWithAnnotation(RpcServer.class,packageName);

        if(MapUtils.isNotEmpty(rpcServiceObject)) {
            for(Object serviceBean : rpcServiceObject.values()) {
                String interfaceName = serviceBean.getClass().getAnnotation(RpcServer.class).value().getName();
                handlerMap.put(interfaceName, serviceBean);
            }
        }
    }

    public void afterPropertiesSet(){
        ThreadFactory threadRpcFactory = new NamedThreadFactory("NettyRPC ThreadFactory");
        int parallel = Runtime.getRuntime().availableProcessors() * 2;
        EventLoopGroup boss = new NioEventLoopGroup();
        EventLoopGroup worker = new NioEventLoopGroup(parallel, threadRpcFactory, SelectorProvider.provider());
        ServerBootstrap bootstrap = new ServerBootstrap();
        bootstrap.group(boss, worker).channel(NioServerSocketChannel.class)
                .childHandler(new MessageRecvChannelInitializer(handlerMap).buildRpcSerializeProtocol(new SerializeContext(serializeFrame)))
                .option(ChannelOption.SO_BACKLOG, 128)
                .childOption(ChannelOption.SO_KEEPALIVE, true);
        String[] ipAddr = serverAddress.split(MessageRecvExecutor.DELIMITER);
        if(ipAddr.length == 2){
            String host = ipAddr[0];
            int port = Integer.parseInt(ipAddr[1]);
            try {
                ChannelFuture future = bootstrap.bind(host, port).sync();
                LOG.info(String.format("Netty RPC Server start success!\nip:%s\nport:%d\nprotocol:%s\n\n", host, port, serializeFrame.getClass().getSimpleName()));
                future.channel().closeFuture().sync();
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                worker.shutdownGracefully();
                boss.shutdownGracefully();
            }

        }else {
            LOG.error("Netty RPC Server start fail!\n");
        }
    }

}
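The submit-then-callback flow in `MessageRecvExecutor.submit` can be sketched with the JDK's `CompletableFuture` instead of Guava. This is a substitution for illustration: `AsyncDispatcher` is a hypothetical name, the task returns a `String` instead of filling a `MessageResponse`, and the `writer` callback stands in for `context.writeAndFlush(response)`.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;

// Runs business logic on its own pool and hands the result to `writer` when it completes
final class AsyncDispatcher implements AutoCloseable {
    private final ExecutorService pool = Executors.newFixedThreadPool(4);

    CompletableFuture<Void> submit(Callable<String> task, Consumer<String> writer) {
        return CompletableFuture
                .supplyAsync(() -> {
                    try {
                        return task.call();
                    } catch (Exception e) {
                        // Like MessageRecvInitializeTask, turn a failure into an error payload
                        return "error: " + e;
                    }
                }, pool)
                .thenAccept(writer);   // runs only after the business logic has finished
    }

    @Override
    public void close() {
        pool.shutdown();
    }
}
```

The point of the design is that slow service methods run on the business pool, so Netty's I/O threads stay free to read and write other connections.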

On the server side, the actual invocation runs inside a Callable and is performed through reflection:

public class MessageRecvInitializeTask implements Callable<Boolean> {
    private MessageRequest request = null;
    private MessageResponse response = null;
    private Map handlerMap = null;

    public MessageResponse getResponse() {
        return response;
    }

    public MessageRequest getRequest() {
        return request;
    }

    public void setRequest(MessageRequest request) {
        this.request = request;
    }

    MessageRecvInitializeTask(MessageRequest request, MessageResponse response, Map handlerMap) {
        this.request = request;
        this.response = response;
        this.handlerMap = handlerMap;
    }

    @Override
    public Boolean call() throws Exception {
        response.setMessageId(request.getMessageId());
        try {
            Object result = reflect(request);
            response.setResult(result);
            return Boolean.TRUE;
        } catch (Throwable t) {
            response.setError(t.toString());
            t.printStackTrace();
            return Boolean.FALSE;
        }
    }

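    private Object reflect(MessageRequest request) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
        String className = request.getClassName();
        Object serviceBean = handlerMap.get(className);
        if (serviceBean == null) {
            throw new IllegalStateException("No service registered for interface " + className);
        }
        String methodName = request.getMethodName();
        Object[] parameters = request.getParameters();
        // Pass the declared parameter types so overloads and null arguments resolve correctly
        return MethodUtils.invokeMethod(serviceBean, methodName, parameters, request.getTypeParameters());
    }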
}
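For comparison, here is the same dispatch written with plain `java.lang.reflect`, resolving the method by name and declared parameter types (the `typeParameters` field of `MessageRequest`) so that overloads are told apart. `Calculator`, `CalculatorImpl` and `ReflectDispatcher` are hypothetical names used only for this sketch.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

interface Calculator {                 // hypothetical service interface
    int add(int a, int b);

    String add(String a, String b);    // overload, resolved through parameter types
}

class CalculatorImpl implements Calculator {
    public int add(int a, int b) { return a + b; }

    public String add(String a, String b) { return a + b; }
}

final class ReflectDispatcher {
    private final Map<String, Object> handlerMap = new HashMap<>();

    void register(Class<?> iface, Object bean) {
        handlerMap.put(iface.getName(), bean);
    }

    // Finds the bean by interface name, then the exact method by name + parameter types
    Object invoke(String className, String methodName, Class<?>[] types, Object[] args) {
        Object bean = handlerMap.get(className);
        if (bean == null) {
            throw new IllegalStateException("no service registered for " + className);
        }
        try {
            // Looking the method up on the interface exposes only interface methods to callers
            Method method = Class.forName(className).getMethod(methodName, types);
            return method.invoke(bean, args);
        } catch (InvocationTargetException e) {
            // The service method itself threw; surface its exception
            throw new RuntimeException(e.getCause());
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("cannot invoke " + className + "." + methodName, e);
        }
    }
}
```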

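Client and server must agree on the same interfaces, with the implementations living on the server; the client then wraps its calls in a MessageSendExecutor. The client side also uses Netty: it sends the MessageRequest and deserializes the returned message into a MessageResponse. Connection setup is handled by RpcServerLoader: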

public class RpcServerLoader {
    private volatile static RpcServerLoader rpcServerLoader;
    private final static String DELIMITER = ":";

    private final static int parallel = Math.max(2,Runtime.getRuntime().availableProcessors() * 2);

    private EventLoopGroup eventLoopGroup = new NioEventLoopGroup(parallel);

    private static ListeningExecutorService threadPoolExecutor = MoreExecutors.listeningDecorator((ThreadPoolExecutor) RpcThreadPool.getExecutor(16, -1));
    private MessageSendHandler messageSendHandler;
    private Lock lock = new ReentrantLock();
    private Condition connectStatus = lock.newCondition();
    private Condition handlerStatus = lock.newCondition();

    private RpcServerLoader() {
    }

    public static RpcServerLoader getInstance() {
        if (rpcServerLoader == null) {
            synchronized (RpcServerLoader.class) {
                if (rpcServerLoader == null) {
                    rpcServerLoader = new RpcServerLoader();
                }
            }
        }
        return rpcServerLoader;
    }

    public void load(String serverAddress, SerializeFrame serializeFrame) {
        String[] ipAddr = serverAddress.split(RpcServerLoader.DELIMITER);
        if (ipAddr.length == 2) {
            String host = ipAddr[0];
            int port = Integer.parseInt(ipAddr[1]);
            final InetSocketAddress remoteAddr = new InetSocketAddress(host, port);
            ListenableFuture<Boolean> listenableFuture = threadPoolExecutor.submit(new MessageSendInitializeTask(eventLoopGroup, remoteAddr, serializeFrame));
            Futures.addCallback(listenableFuture, new FutureCallback<Boolean>() {
                @Override
                public void onSuccess(@Nullable Boolean aBoolean) {
                    lock.lock();
                    try {
                        // Wait until the connect listener has registered the handler
                        while (messageSendHandler == null) {
                            handlerStatus.await();
                        }
                        // Wake up callers blocked in getMessageSendHandler()
                        if (Boolean.TRUE.equals(aBoolean)) {
                            connectStatus.signalAll();
                        }
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        lock.unlock();
                    }

                }

                @Override
                public void onFailure(Throwable throwable) {
                    throwable.printStackTrace();
                }
            }, threadPoolExecutor);
        }
    }

    public void setMessageSendHandler(MessageSendHandler messageInHandler) {
        try {
            lock.lock();
            this.messageSendHandler = messageInHandler;
            handlerStatus.signal();
        } finally {
            lock.unlock();
        }
    }

    public MessageSendHandler getMessageSendHandler() throws InterruptedException {
        try {
            lock.lock();
            // Re-check in a loop: await() can return spuriously
            while (messageSendHandler == null) {
                connectStatus.await();
            }
            return messageSendHandler;
        } finally {
            lock.unlock();
        }
    }

    public void unLoad() {
        if (messageSendHandler != null) {
            messageSendHandler.close();
        }
        threadPoolExecutor.shutdown();
        eventLoopGroup.shutdownGracefully();
    }
}
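The lock and the two `Condition`s in `RpcServerLoader` implement a one-shot "the handler is ready" signal. A hedged alternative, sketched below with the hypothetical name `HandlerHolder`, uses a `CompletableFuture` for the same purpose; it also lets a failed connection release the waiters and puts a timeout on the wait, which the condition-based version does not do.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// One-shot "handler is ready" signal; H is a placeholder for MessageSendHandler
final class HandlerHolder<H> {
    private final CompletableFuture<H> ready = new CompletableFuture<>();

    // Called from the connect listener once the channel is up; later calls are ignored
    void setHandler(H handler) {
        ready.complete(handler);
    }

    // Called if the connection attempt fails, so waiters are released instead of blocking forever
    void fail(Throwable cause) {
        ready.completeExceptionally(cause);
    }

    // Blocks until the handler exists, the connection failed, or the timeout elapses
    H getHandler(long timeout, TimeUnit unit) {
        try {
            return ready.get(timeout, unit);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for the connection", e);
        } catch (ExecutionException | TimeoutException e) {
            throw new IllegalStateException("RPC connection not available", e);
        }
    }
}
```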

The Callable the client submits to open the connection:

public class MessageSendInitializeTask implements Callable<Boolean> {
    private EventLoopGroup eventLoopGroup = null;
    private InetSocketAddress serverAddress;
    private SerializeFrame serializeFrame;

    public MessageSendInitializeTask(EventLoopGroup eventLoopGroup, InetSocketAddress serverAddress, SerializeFrame serializeFrame) {
        this.eventLoopGroup = eventLoopGroup;
        this.serverAddress = serverAddress;
        this.serializeFrame = serializeFrame;
    }

    @Override
    public Boolean call() throws Exception {
        Bootstrap b = new Bootstrap();
        b.group(eventLoopGroup).channel(NioSocketChannel.class).option(ChannelOption.SO_KEEPALIVE, true);
        b.handler(new MessageSendChannelInitializer().buildRpcSerializeProtocol(new SerializeContext(serializeFrame)));
        ChannelFuture future = b.connect(serverAddress);
        future.addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if(future.isSuccess()){
                    MessageSendHandler handler = future.channel().pipeline().get(MessageSendHandler.class);
                    RpcServerLoader.getInstance().setMessageSendHandler(handler);
                }
            }
        });

        return Boolean.TRUE;
    }
}

MessageSendHandler mainly handles the returned MessageResponse:

public class MessageSendHandler extends ChannelInboundHandlerAdapter {
    private final ConcurrentHashMap<String, MessageCallBack> mapCallBack = new ConcurrentHashMap<>();
    private volatile Channel channel;
    private SocketAddress remoteAddr;
    public Channel getChannel() {
        return channel;
    }

    public SocketAddress getRemoteAddr() {
        return remoteAddr;
    }

    @Override
    public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
        super.channelRegistered(ctx);
        this.channel = ctx.channel();
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        super.channelActive(ctx);
        this.remoteAddr = this.channel.remoteAddress();
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        MessageResponse response = (MessageResponse) msg;
        String messageId = response.getMessageId();
        // remove() looks up and unregisters in one atomic step, so each callback completes once
        MessageCallBack callBack = mapCallBack.remove(messageId);
        if (callBack != null) {
            callBack.over(response);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        ctx.close();
    }

    public void close() {
        channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
    }

    public MessageCallBack sendRequest(MessageRequest request){
        MessageCallBack callBack = new MessageCallBack();
        mapCallBack.put(request.getMessageId(), callBack);
        channel.writeAndFlush(request);
        return callBack;
    }
}
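The core of `MessageSendHandler` is a map from `messageId` to a pending call, completed when the matching `MessageResponse` arrives. The sketch below models just that bookkeeping with `CompletableFuture` in place of `MessageCallBack`; `PendingCalls` is a hypothetical name, and the `wire` consumer stands in for `channel.writeAndFlush(request)` by receiving only the message ID.

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

// Correlates asynchronous responses with the calls that produced them
final class PendingCalls {
    private final Map<String, CompletableFuture<Object>> pending = new ConcurrentHashMap<>();
    private final Consumer<String> wire;

    PendingCalls(Consumer<String> wire) {
        this.wire = wire;
    }

    // Registers the call before writing, so even a very fast response finds its entry
    CompletableFuture<Object> send() {
        String messageId = UUID.randomUUID().toString();
        CompletableFuture<Object> future = new CompletableFuture<>();
        pending.put(messageId, future);
        wire.accept(messageId);
        return future;
    }

    // Invoked from channelRead; remove() is atomic, so a duplicate response is ignored
    void onResponse(String messageId, Object result) {
        CompletableFuture<Object> future = pending.remove(messageId);
        if (future != null) {
            future.complete(result);
        }
    }

    int pendingCount() {
        return pending.size();
    }
}
```

Responses may arrive in any order; matching by ID rather than by arrival order is what lets many concurrent calls share one connection.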
